
Comments (39)

GorkaAbad commented on August 17, 2024 (+9)

Hi,
here is a working MNIST example using CUDA, reusing some code from above. It may be verbose and far from optimal.

I get around 73% test accuracy in about 1 minute. Playing with the network size may improve the performance.

import cv2
import numpy as np
import torch
import torchvision
from kan import KAN
import matplotlib.pyplot as plt


def train_acc():
    # the model sometimes ends up on the CPU here (a quirk of KAN's implementation),
    # so fall back to CPU tensors if the first attempt fails
    try:
        arg = (
            torch.argmax(model(dataset["train_input"]), dim=1) == dataset["train_label"]
        )
    except RuntimeError:
        arg = torch.argmax(model(dataset["train_input"].to("cpu")), dim=1) == dataset[
            "train_label"
        ].to("cpu")
    return torch.mean(arg.float())


def test_acc():
    try:
        arg = torch.argmax(model(dataset["test_input"]), dim=1) == dataset["test_label"]
    except RuntimeError:
        arg = torch.argmax(model(dataset["test_input"].to("cpu")), dim=1) == dataset[
            "test_label"
        ].to("cpu")

    return torch.mean(arg.float())


train_data = torchvision.datasets.MNIST(
    root="./data", train=True, download=True, transform=None
)
test_data = torchvision.datasets.MNIST(
    root="./data", train=False, download=True, transform=None
)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using: {device}")
valid_labels = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]

X_train = []
y_train = []

for pil_img, label in train_data:
    if label in valid_labels:
        x = np.array(pil_img)
        x = cv2.resize(x, (7, 7))
        X_train.append(x.astype(float))
        y_train.append(label)

X_train = np.array(X_train)
y_train = np.array(y_train)

mean, std = np.mean(X_train), np.std(X_train)
print(f"{mean=}")
print(f"{std=}")

X_test = []
y_test = []
for pil_img, label in test_data:
    if label in valid_labels:
        x = np.array(pil_img)
        x = cv2.resize(x, (7, 7))
        X_test.append(x.astype(float))
        y_test.append(label)

X_test = np.array(X_test)
y_test = np.array(y_test)

X_test = (X_test - mean) / std
X_train = (X_train - mean) / std


# x still holds the last 7x7 image from the loop above, so x.shape[0] ** 2 == 49
model = KAN(width=[x.shape[0] ** 2, 20, len(valid_labels)], grid=5, k=3, device=device)

dataset = {}
dataset["train_input"] = (
    torch.flatten(torch.from_numpy(X_train), start_dim=1).long().to(device)
)
dataset["train_label"] = torch.from_numpy(y_train).long().to(device)

dataset["test_input"] = (
    torch.flatten(torch.from_numpy(X_test), start_dim=1).long().to(device)
)
dataset["test_label"] = torch.from_numpy(y_test).long().to(device)

loss_fn = torch.nn.CrossEntropyLoss()

result = model.train(
    dataset,
    opt="Adam",
    steps=50,
    lr=0.1,
    batch=512,
    # metrics=(
    #     train_acc,
    #     test_acc,
    #     ),  # this is the slower step, so it's better to evaluate it after training
    loss_fn=loss_fn,
    # device=device,
)

acc = test_acc()
print(f"Test accuracy: {acc.item()}")


plt.plot(result["train_loss"], label="train_loss")
plt.plot(result["test_loss"], label="test_loss")
plt.ylim(0, 5)
plt.legend()
plt.savefig("loss.png")

juntaoJianggavin commented on August 17, 2024 (+8)

Also let me know, how can I integrate and train KAN layers with CNNs after flattening the tensors? Could anybody please share the code?

I tried replacing the MLP with KAN in CNN models, and their performance is close to each other.

https://github.com/juntaoJianggavin/kan-cifar10/tree/main

MeDenTec commented on August 17, 2024 (+7)

Also let me know, how can I integrate and train KAN layers with CNNs after flattening the tensors? Could anybody please share the code?

xiaol commented on August 17, 2024 (+6)

According to my experiments, a modified version of KAN outperformed an MLP of the same shape (both 768-64-10) on the MNIST dataset, using the efficient-kan code above with some tweaks (a same-shape comparison sketch follows below).

This is KAN+:
(training-curve screenshot omitted)

This is the MLP:
(training-curve screenshot omitted)
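For reference, a minimal sketch of such a same-shape comparison; the efficient_kan import path and KAN constructor are assumed from the Blealtan/efficient-kan repo, and the 768-64-10 shape is taken from the comment above:

import torch.nn as nn
from efficient_kan import KAN  # assumed: https://github.com/Blealtan/efficient-kan

kan = KAN([768, 64, 10])
mlp = nn.Sequential(nn.Linear(768, 64), nn.ReLU(), nn.Linear(64, 10))
print("KAN params:", sum(p.numel() for p in kan.parameters()))
print("MLP params:", sum(p.numel() for p in mlp.parameters()))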

Fredrik00 commented on August 17, 2024 (+5)

I also think a pure KAN implementation for computer vision does not look very promising, since it makes no use of spatial locality. An interesting idea could be to define a KAN-based 2D convolution layer that replaces the 2D kernel with a spline (or a full KAN layer?) operating on flattened 2D patches of sizes similar to regular kernels. At small enough kernel sizes (say 3x3), the loss of fine-grained spatial locality might not be as detrimental to model performance. A rough sketch of this idea follows below.
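A minimal, hedged sketch of that idea (not code from any repo in this thread): nn.Unfold extracts the flattened patches, and each patch element passes through its own learnable 1D function. For brevity, the learnable function here uses a Gaussian RBF basis, a common stand-in for B-splines (as in fast-kan); all names and sizes are illustrative.

import torch
import torch.nn as nn

class KANConv2d(nn.Module):
    """Conv-like layer: a learnable 1D function per (output channel, patch element)."""
    def __init__(self, in_ch, out_ch, kernel_size=3, num_basis=8, stride=1, padding=1):
        super().__init__()
        self.unfold = nn.Unfold(kernel_size, stride=stride, padding=padding)
        self.kernel_size, self.stride, self.padding = kernel_size, stride, padding
        in_feat = in_ch * kernel_size * kernel_size
        # fixed RBF centers shared by all edges; one coefficient per (out_ch, edge, basis)
        self.register_buffer("centers", torch.linspace(-2, 2, num_basis))
        self.log_gamma = nn.Parameter(torch.zeros(1))
        self.coef = nn.Parameter(torch.randn(out_ch, in_feat, num_basis) * 0.1)

    def forward(self, x):
        b, _, h, w = x.shape
        patches = self.unfold(x)  # (b, in_feat, L), one column per output position
        # Gaussian RBF basis evaluated at every patch element: (b, in_feat, L, num_basis)
        phi = torch.exp(-torch.exp(self.log_gamma)
                        * (patches.unsqueeze(-1) - self.centers) ** 2)
        out = torch.einsum("bilk,oik->bol", phi, self.coef)  # sum the edge functions
        h_out = (h + 2 * self.padding - self.kernel_size) // self.stride + 1
        w_out = (w + 2 * self.padding - self.kernel_size) // self.stride + 1
        return out.view(b, -1, h_out, w_out)

# e.g. KANConv2d(1, 6)(torch.randn(2, 1, 28, 28)).shape == (2, 6, 28, 28)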

AlexBodner commented on August 17, 2024 (+5)

We implemented KAN Convolutional Layers; check out our repo, based on the efficient-kan implementation:
https://github.com/AntonioTepsich/Convolutional-KANs

Menghuan1918 commented on August 17, 2024 (+4)

Hi, here's my attempted code; it takes about 30 s to run on CUDA and reaches about 83% accuracy.

import cv2
import numpy as np
import torch
import torchvision
from kan import KAN
import matplotlib.pyplot as plt

def preprocess_data(data):
    images = []
    labels = []
    for img, label in data:
        img = cv2.resize(np.array(img), (7, 7))
        img = img.flatten() / 255.0
        images.append(img)
        labels.append(label)
    return np.array(images), np.array(labels)

train_data = torchvision.datasets.MNIST(
    root="./mnist_data", train=True, download=True, transform=None
)
test_data = torchvision.datasets.MNIST(
    root="./mnist_data", train=False, download=True, transform=None
)

train_images, train_labels = preprocess_data(train_data)
test_images, test_labels = preprocess_data(test_data)

device = "cuda" if torch.cuda.is_available() else "cpu"

print(f"Using {device} device")

dataset = {
    "train_input": torch.from_numpy(train_images).float().to(device),
    "train_label": torch.from_numpy(train_labels).to(device),
    # test tensors stay on the CPU: the model is reloaded on the CPU below for evaluation
    "test_input": torch.from_numpy(test_images).float().to("cpu"),
    "test_label": torch.from_numpy(test_labels).to("cpu"),
}

model = KAN(width=[49, 10, 10], device=device)

results = model.train(
    dataset,
    opt="Adam",
    lr=0.05,
    steps=100,
    batch=512,
    loss_fn=torch.nn.CrossEntropyLoss(),
)
torch.save(model.state_dict(), "kan.pth")


del model
model = KAN(width=[49, 10, 10], device="cpu")
model.load_state_dict(torch.load("kan.pth"))

def test_acc():
    with torch.no_grad():
        predictions = torch.argmax(model(dataset["test_input"]), dim=1)
        correct = (predictions == dataset["test_label"]).float()
        accuracy = correct.mean()
    return accuracy

acc = test_acc()
print(f"Test accuracy: {acc.item() * 100:.2f}%")

plt.plot(results["train_loss"], label="train")
plt.plot(results["test_loss"], label="test")
plt.legend()
plt.savefig("kan.png")

(output plot omitted)

GeorgeDeac commented on August 17, 2024 (+4)

Probably, after all, the representation power of KANs depends a lot on the distribution and shape of the data. The spline representation imposes constraints that make some shapes harder to represent, in contrast to MLPs, which don't care that much about extreme shapes. There are still cases where I think KANs consistently outperform MLPs, but I guess it depends a lot on the data domain we are dealing with.

I also saw some implementations that use RBFs instead of splines for KANs. I imagine RBFs are similar in spirit; they would compare better to MLPs if our data contains Gaussian shapes and has some normality.
https://github.com/ZiyaoLi/fast-kan

I also saw many more that replace the standard KAN's spline basis with different polynomial representations (like Chebyshev), and even a wavelet basis:
https://github.com/SynodicMonth/ChebyKAN
https://github.com/mlsquare/xKAN

But at the end of the day, I think all of these are biased towards better representing certain domains / data shapes and might not scale universally (depending on the data).

It would be beneficial to build a synthetic-data benchmark with consistent examples of extreme shapes / gradients and other edge cases, and test all these architectures against MLPs. As an illustration of the basis swap, see the sketch below.
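To make the basis swap concrete, here is a minimal sketch of a KAN-style layer whose per-edge functions are Chebyshev expansions instead of B-splines, in the spirit of ChebyKAN (names and initialization here are illustrative, not that repo's API):

import torch
import torch.nn as nn

class ChebyKANLayer(nn.Module):
    def __init__(self, in_dim, out_dim, degree=4):
        super().__init__()
        self.degree = degree
        # one Chebyshev coefficient vector per edge (out_dim x in_dim edges)
        self.coef = nn.Parameter(torch.randn(out_dim, in_dim, degree + 1) * 0.1)

    def forward(self, x):
        x = torch.tanh(x)  # Chebyshev polynomials need inputs in [-1, 1]
        T = [torch.ones_like(x), x]
        for _ in range(2, self.degree + 1):
            T.append(2 * x * T[-1] - T[-2])  # recurrence T_n(x) = 2x T_{n-1}(x) - T_{n-2}(x)
        T = torch.stack(T, dim=-1)           # (batch, in_dim, degree + 1)
        # each edge applies its own polynomial; contributions are summed per output
        return torch.einsum("bik,oik->bo", T, self.coef)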

XiangboGaoBarry commented on August 17, 2024 (+3)

Hi, here I implement ConvKAN with different activation formulations and their corresponding inference times: https://github.com/XiangboGaoBarry/ConvKAN-Zoo
We evaluate the results on the CIFAR-10 dataset.

tommarvoloriddle commented on August 17, 2024 (+3)

We are trying to use KAN in ViT to replace the MLP for training on ImageNet, and we welcome collaborators!
Vision-KAN

KindXiaoming commented on August 17, 2024 (+2)

Yeah, I think KANs, as they are right now, cannot handle convolution. It seems reasonable to define ConvKAN layers. Given the current implementation, the only thing you can do with vision tasks is flatten a whole image into a vector, totally abandoning spatial information (which is not good; that's why I think extra development is needed).

KindXiaoming commented on August 17, 2024 (+2)

As a quick cute example, you may try playing with KAN as if it were an MLP for MNIST.

Please make sure the input data has shape [data size, indim], with indim=784. Also, the input dimension of the KAN should be 784 and the output 10. So, e.g., these KANs are valid for MNIST: KAN(width=[784,5,10]) or KAN(width=[784,5,5,10]). You may also want to pass, say, batch=128 to model.train() to train on batches rather than the whole dataset (which is fine, but I worry it might run too slowly on CPU, haha). A minimal sketch of this setup follows below.
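For reference, a minimal sketch of that setup (assuming a dataset dict holding the flattened MNIST tensors, as in the examples elsewhere in this thread):

import torch
from kan import KAN

model = KAN(width=[784, 5, 10], grid=3, k=3)
results = model.train(dataset, opt="Adam", steps=50, batch=128,
                      loss_fn=torch.nn.CrossEntropyLoss())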

cpellet commented on August 17, 2024 (+2)

No, unfortunately. @WuZhuoran's comments make sense; I was solely making a point about getting training on mps to work.

genglinxiao commented on August 17, 2024 (+2)

Very interesting. I'd really like to see a direct comparison between KAN and MLP in a CNN architecture.

Shomvel commented on August 17, 2024 (+2)

I compared KAN with a 5x smaller MLP. In 10 epochs, the KAN reached 91% accuracy whereas the MLP reached 97%. The KAN loss also goes down more slowly than the MLP's.

import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
from torchvision import transforms
from torch.utils.data import DataLoader

# EKAN: the KAN implementation from https://github.com/Blealtan/efficient-kan
# (body omitted here; see that repo)
class EKAN(nn.Module):
    pass

# a simple MLP model
class MLP(nn.Module):
    def __init__(self):
        super(MLP, self).__init__()
        self.layers = nn.Sequential(
            nn.Linear(28*28, 128),
            nn.ReLU(),
            nn.Linear(128, 64),
            nn.ReLU(),
            nn.Linear(64, 10)
        )

    def forward(self, x):
        return self.layers(x)

# Data preprocessing
transform = transforms.Compose([
    transforms.ToTensor(),  # Convert images to PyTorch tensors and scale to [0,1]
    transforms.Normalize((0.5,), (0.5,))  # Normalize to mean=0.5, std=0.5
])

# Load the datasets
train_dataset = torchvision.datasets.MNIST(root='./data', train=True, download=True, transform=transform)
test_dataset = torchvision.datasets.MNIST(root='./data', train=False, download=True, transform=transform)

# Data loaders
train_loader = DataLoader(dataset=train_dataset, batch_size=64, shuffle=True)
test_loader = DataLoader(dataset=test_dataset, batch_size=64, shuffle=False)

# Set device
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Initialize model, loss function, and optimizer
model = EKAN([28*28, 64, 10]).to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

# Training the model
def train_model(num_epochs):
    model.train()
    for epoch in range(num_epochs):
        total_loss = 0
        for images, labels in train_loader:
            images = images.to(device)
            labels = labels.to(device)
            images = images.view(images.shape[0], -1)

            optimizer.zero_grad()
            outputs = model(images)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            total_loss += loss.item()
        print(f'Epoch [{epoch+1}/{num_epochs}], Loss: {total_loss/len(train_loader):.4f}')

train_model(10)

# Testing the model
def test_model():
    model.eval()
    with torch.no_grad():
        correct = 0
        total = 0
        for images, labels in test_loader:
            images = images.to(device)
            labels = labels.to(device)
            images = images.view(images.shape[0], -1)
            outputs = model(images)
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
        accuracy = 100 * correct / total
        print(f'Test Accuracy: {accuracy:.2f}%')

test_model()

zdx3578 commented on August 17, 2024 (+2)

Before using conv2d, what about using a VAE latent space: train a KAN on MNIST using the VAE encoder output as the KAN input? A rough sketch is below.
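A rough sketch of that suggestion (all names here are hypothetical: vae_encoder is assumed to be a pretrained module mapping images to latent vectors, and the train/test tensors are assumed to be prepared as elsewhere in this thread):

import torch
from kan import KAN

latent_dim = 16  # hypothetical latent size
with torch.no_grad():
    z_train = vae_encoder(train_images)  # (N, latent_dim), hypothetical pretrained encoder
    z_test = vae_encoder(test_images)

dataset = {
    "train_input": z_train, "train_label": train_labels,
    "test_input": z_test, "test_label": test_labels,
}
model = KAN(width=[latent_dim, 10, 10])
model.train(dataset, opt="Adam", steps=50, batch=128,
            loss_fn=torch.nn.CrossEntropyLoss())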

MeDenTec commented on August 17, 2024 (+1)

Hi everybody, please let me know: has anybody successfully applied KANs to any computer vision tasks, or integrated them with CNNs?

KindXiaoming commented on August 17, 2024 (+1)

My experience with MNIST is that a 2-layer KAN with an extremely small hidden layer (say 5 or 10 neurons) is enough to train MNIST (though maybe my impression was based on accuracy), i.e., KAN(width=[49, 10, 3]) in your case. It's likely that accuracies are high even while losses are high.

So please try computing accuracy as well. You can refer to this tutorial to see how to do it. Basically, it's something like:

def train_acc():
    return torch.mean((torch.round(model(dataset['train_input'])[:,0]) == dataset['train_label'][:,0]).float())

def test_acc():
    return torch.mean((torch.round(model(dataset['test_input'])[:,0]) == dataset['test_label'][:,0]).float())

results = model.train(dataset, opt="LBFGS", steps=20, metrics=(train_acc, test_acc));
results['train_acc'][-1], results['test_acc'][-1]
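For the 10-class cross-entropy setup used elsewhere in this thread, an argmax variant of the same metric would look like this (a sketch, assuming integer class labels rather than the [:,0]-style labels above):

def test_acc_argmax():
    preds = torch.argmax(model(dataset['test_input']), dim=1)
    return torch.mean((preds == dataset['test_label']).float())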

HaiFengZeng commented on August 17, 2024 (+1)

I think a combination of them (nn.Linear + KAN) works fine for the MNIST task:

import torchvision
import torch
from torchvision import transforms 
import torch.nn as nn
import torch.nn.functional as F
from kan import KAN
import tqdm
transform = transforms.Compose(
    [transforms.ToTensor(),
    #  transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]
    ]
    )

trainset = torchvision.datasets.MNIST(root='./MNIST', train=True,
                                        download=True, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=500,
                                          shuffle=True, num_workers=2)
testset = torchvision.datasets.MNIST(root='./MNIST', train=False,
                                       download=True, transform=transform)
testloader = torch.utils.data.DataLoader(testset, batch_size=500,
                                         shuffle=False, num_workers=2)
print(len(trainset),len(testset))
class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(28*28,64).cuda()
        self.kan = KAN(width=[64,16,10], grid=5, k=3, seed=0,device='cuda:0')
    
    def forward(self,x):
        x = self.linear(x)
        out = self.kan(x)
        return out


net = Net().cuda()

import torch.optim as optim

criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(net.parameters(), lr=0.002,)

for epoch in range(4):  # loop over the dataset multiple times
    running_loss = 0.0
    for i, data in tqdm.tqdm(enumerate(trainloader, 0)):
        # get the inputs; data is a list of [inputs, labels]
        inputs, labels = data
        
        # print('predict.size=',pred.size())
        # zero the parameter gradients
        optimizer.zero_grad()
        # forward + backward + optimize
        x = inputs.view(inputs.size(0),-1).cuda()
        outputs = net(x)
        loss = criterion(outputs, labels.cuda())
        loss.backward()
        optimizer.step()

        # print statistics
        running_loss += loss.item()
        if i % 100 == 99:    # print every 100 mini-batches
            print(f'[{epoch + 1}, {i + 1:5d}] loss: {running_loss / 2000:.3f}')
            running_loss = 0.0
    print(f'[{epoch + 1}, {i + 1:5d}] loss: {running_loss / 2000:.3f}')
    correct = 0
    total = 0
    # net.eval()
    # since we're not training, we don't need to calculate the gradients for our outputs
    with torch.no_grad():
        for data in testloader:
            inputs, labels = data
            # calculate outputs by running images through the network
            x = inputs.view(inputs.size(0),-1).cuda()
            outputs = net(x)
            # the class with the highest energy is what we choose as prediction
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels.cuda()).sum().item()

        print(f'epoch {epoch} Accuracy of the network on the 10000 test images: {100 * correct // total} %')
    # net.train()
print('Finished Training')

After 4 epochs of training, accuracy reaches 96%; the logs look like:

60000 10000
99it [00:52,  1.90it/s][1,   100] loss: 0.028
120it [01:03,  1.89it/s]
[1,   120] loss: 0.002
epoch 0 Accuracy of the network on the 10000 test images: 93 %
99it [00:51,  1.95it/s][2,   100] loss: 0.010
120it [01:02,  1.91it/s]
[2,   120] loss: 0.002
epoch 1 Accuracy of the network on the 10000 test images: 95 %
99it [00:51,  1.91it/s][3,   100] loss: 0.006
120it [01:02,  1.93it/s]
[3,   120] loss: 0.001
epoch 2 Accuracy of the network on the 10000 test images: 95 %
99it [00:51,  1.91it/s][4,   100] loss: 0.005
120it [01:02,  1.93it/s]
[4,   120] loss: 0.001
epoch 3 Accuracy of the network on the 10000 test images: 96 %
Finished Training

WuZhuoran commented on August 17, 2024

Update on this topic:

I wrote a short notebook to test training and evaluation on the MNIST dataset. If we want to apply KAN to 2D or 3D tasks, one possible way would be to make KANLayer inherit from nn.Conv2d?

Here is the screenshot:

(screenshot omitted)

And here is the Traceback:

description:   0%|                                                           | 0/20 [00:00<?, ?it/s]
---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
Cell In[5], line 1
----> 1 results = model.train(dataset, opt="LBFGS", steps=20, loss_fn=torch.nn.CrossEntropyLoss());

File ~/miniconda3/envs/kan_py39/lib/python3.9/site-packages/kan/KAN.py:913, in KAN.train(self, dataset, opt, steps, log, lamb, lamb_l1, lamb_entropy, lamb_coef, lamb_coefdiff, update_grid, grid_update_num, loss_fn, lr, stop_grid_update_step, batch, small_mag_threshold, small_reg_factor, metrics, sglr_avoid, save_fig, in_vars, out_vars, beta, save_fig_freq, img_folder, device)
    910 test_id = np.random.choice(dataset['test_input'].shape[0], batch_size_test, replace=False)
    912 if _ % grid_update_freq == 0 and _ < stop_grid_update_step and update_grid:
--> 913     self.update_grid_from_samples(dataset['train_input'][train_id].to(device))
    916 if opt == "LBFGS":
    917     optimizer.step(closure)

File ~/miniconda3/envs/kan_py39/lib/python3.9/site-packages/kan/KAN.py:242, in KAN.update_grid_from_samples(self, x)
    219 '''
    220 update grid from samples
    221 
   (...)
    239 tensor([0.0128, 1.0064, 2.0000, 2.9937, 3.9873, 4.9809])
    240 '''
    241 for l in range(self.depth):
--> 242     self.forward(x)
    243     self.act_fun[l].update_grid_from_samples(self.acts[l])

File ~/miniconda3/envs/kan_py39/lib/python3.9/site-packages/kan/KAN.py:313, in KAN.forward(self, x)
    308 self.acts.append(x) # acts shape: (batch, width[l])
    311 for l in range(self.depth):
--> 313     x_numerical, preacts, postacts_numerical, postspline = self.act_fun[l](x)
    315     if self.symbolic_enabled == True:
    316         x_symbolic, postacts_symbolic = self.symbolic_fun[l](x)

File ~/miniconda3/envs/kan_py39/lib/python3.9/site-packages/torch/nn/modules/module.py:1532, in Module._wrapped_call_impl(self, *args, **kwargs)
   1530     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1531 else:
-> 1532     return self._call_impl(*args, **kwargs)

File ~/miniconda3/envs/kan_py39/lib/python3.9/site-packages/torch/nn/modules/module.py:1541, in Module._call_impl(self, *args, **kwargs)
   1536 # If we don't have any hooks, we want to skip the rest of the logic in
   1537 # this function, and just call forward.
   1538 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1539         or _global_backward_pre_hooks or _global_backward_hooks
   1540         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1541     return forward_call(*args, **kwargs)
   1543 try:
   1544     result = None

File ~/miniconda3/envs/kan_py39/lib/python3.9/site-packages/kan/KANLayer.py:172, in KANLayer.forward(self, x)
    170 batch = x.shape[0]
    171 # x: shape (batch, in_dim) => shape (size, batch) (size = out_dim * in_dim)
--> 172 x = torch.einsum('ij,k->ikj', x, torch.ones(self.out_dim,).to(self.device)).reshape(batch, self.size).permute(1,0)
    173 preacts = x.permute(1,0).clone().reshape(batch, self.out_dim, self.in_dim)
    174 base = self.base_fun(x).permute(1,0) # shape (batch, size)

File ~/miniconda3/envs/kan_py39/lib/python3.9/site-packages/torch/functional.py:385, in einsum(*args)
    380     return einsum(equation, *_operands)
    382 if len(operands) <= 2 or not opt_einsum.enabled:
    383     # the path for contracting 0 or 1 time(s) is already optimized
    384     # or the user has disabled using opt_einsum
--> 385     return _VF.einsum(equation, operands)  # type: ignore[attr-defined]
    387 path = None
    388 if opt_einsum.is_available():

RuntimeError: einsum(): the number of subscripts in the equation (2) does not match the number of dimensions (3) for operand 0 and no ellipsis was given

WuZhuoran commented on August 17, 2024

As a quick cute example, you may try playing with KAN as if it were an MLP for MNIST. (KindXiaoming's full comment, quoted from above)

Thanks for the quick reply.

It did work with

model = KAN(width=[784,5,5,10], grid=3, k=3).to(device)

and

dataset['train_input'] = torch.flatten(train_dataset.data, start_dim=1).to(device)
dataset['test_input'] = torch.flatten(test_dataset.data, start_dim=1).to(device)

Now training works on the cpu device (slow, as expected), but it raises an error when using an Apple chip with the mps device:

results = model.train(dataset, opt="LBFGS", steps=20, loss_fn=torch.nn.CrossEntropyLoss(), batch=128, device='cpu');

But anyway, flattening the image into 1D is not a good idea in general. A VisionKAN or KAN_Conv2d needs to be implemented. LOL.

KindXiaoming commented on August 17, 2024

Nice! Yes, there are still some issues with GPU training. Looking forward to your new development :-)

WuZhuoran commented on August 17, 2024

Nice! Yes, there are still some issues with GPU training. Looking forward to your new development :-)

Yeah, about GPU training: I might need to use CUDA first. For MPS, it raises this error in the classification example:

RuntimeError: Expected all tensors to be on the same device, but found at least two devices, mps:0 and cpu!

I already put the datasets and the model on the mps device:

dataset['train_input'] = torch.from_numpy(train_input).to(torch.float32).to(device)
dataset['test_input'] = torch.from_numpy(test_input).to(torch.float32).to(device)
dataset['train_label'] = torch.from_numpy(train_label[:,None]).to(torch.float32).to(device)
dataset['test_label'] = torch.from_numpy(test_label[:,None]).to(torch.float32).to(device)

model = KAN(width=[2,1], grid=3, k=3).to(torch.float32).to(device)

results = model.train(dataset, opt="LBFGS", steps=20, metrics=(train_acc, test_acc), device=device);

The full traceback is:

description:   0%|                                                           | 0/20 [00:00<?, ?it/s]
---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
(the traceback is identical to the one above, passing through KAN.train, KAN.update_grid_from_samples, KAN.forward, and KANLayer.forward into torch.einsum, and ends with:)

RuntimeError: Expected all tensors to be on the same device, but found at least two devices, mps:0 and cpu!

cpellet commented on August 17, 2024

It turns out that replacing model = KAN(width=[784,5,5,10], grid=3, k=3).to(device) with model = KAN(width=[784,5,5,10], grid=3, k=3, device=device) does the trick for me! Here is a full example training on mps, for reference:

import cv2
import numpy as np
import torch
from kan import *
from tensorflow import keras

device = "mps"
# note: the final width is 128 as posted; 10 would match MNIST's 10 classes
model = KAN(width=[7*7, 5, 5, 128], grid=3, k=3, device=device)

(X_train,y_train),(X_test,y_test) = keras.datasets.mnist.load_data()
X_train, X_test = X_train / 255.0, X_test / 255.0

# downsample to 7x7
X_train = np.array([cv2.resize(x, (7,7)) for x in X_train])
X_test = np.array([cv2.resize(x, (7,7)) for x in X_test])

dataset = {}
dataset['train_input'] = torch.flatten(torch.from_numpy(X_train), start_dim=1).to(torch.float32).to(device)
dataset['train_label'] = torch.from_numpy(y_train).to(torch.float32).to(device)
dataset['test_input'] = torch.flatten(torch.from_numpy(X_test), start_dim=1).to(torch.float32).to(device)
dataset['test_label'] = torch.from_numpy(y_test).to(torch.float32).to(device)

model.train(dataset, opt="LBFGS", steps=20, batch=128)

noahvandal commented on August 17, 2024

Were you able to actually train on MNIST using a flat dataset?

MiXaiLL76 commented on August 17, 2024

I tried something like this, but it didn't work; the loss decreases slowly.

import cv2
import numpy as np
import torch
import torchvision
from kan import KAN
import matplotlib.pyplot as plt

train_data = torchvision.datasets.MNIST(
    root="./mnist_data", train=True, download=True, transform=None
)
test_data = torchvision.datasets.MNIST(
    root="./mnist_data", train=False, download=True, transform=None
)

valid_labels = [0, 1, 2]

X_train = []
y_train = []

for pil_img, label in train_data:
    if label in valid_labels:
        x = np.array(pil_img)
        x = cv2.resize(x, (7, 7))
        X_train.append(x.astype(float))
        y_train.append(label)

X_train = np.array(X_train)
y_train = np.array(y_train)

mean, std = np.mean(X_train), np.std(X_train)
print(f"{mean=}")
print(f"{std=}")

X_test = []
y_test = []
for pil_img, label in test_data:
    if label in valid_labels:
        x = np.array(pil_img)
        x = cv2.resize(x, (7, 7))
        X_test.append(x.astype(float))
        y_test.append(label)

X_test = np.array(X_test)
y_test = np.array(y_test)

X_test = (X_test - mean) / std
X_train = (X_train - mean) / std

device = "cpu"
model = KAN(width=[x.shape[0]**2, 20, 20, len(valid_labels)], grid=3, k=3, device=device)

dataset = {}
dataset["train_input"] = (
    torch.flatten(torch.from_numpy(X_train), start_dim=1).to(torch.float32).to(device)
)
dataset["train_label"] = torch.from_numpy(y_train).to(torch.float32).to(device)
dataset["test_input"] = (
    torch.flatten(torch.from_numpy(X_test), start_dim=1).to(torch.float32).to(device)
)
dataset["test_label"] = torch.from_numpy(y_test).to(torch.float32).to(device)

# note: batch=len(valid_labels) is a batch size of only 3, which may be part of why the loss decreases slowly
result = model.train(dataset, opt="Adam", steps=100, lr=0.1, batch=len(valid_labels), device=device)

plt.plot(result['train_loss'], label="train_loss")
plt.plot(result['test_loss'], label="test_loss")
plt.ylim(0, 5)
plt.legend()
plt.show()

MeDenTec commented on August 17, 2024

Very interesting. I'd really like to see a direct comparison between KAN and MLP in CNN architecture.

I am also willing to do so, but I don't know how to integrate and train them simultaneously.

WuZhuoran commented on August 17, 2024

Were you able to actually train on MNIST using a flat dataset?

Hi,

I did train on the MNIST dataset, but only by flattening the images into 1D vectors. I think we still need more development for computer vision tasks.

WuZhuoran commented on August 17, 2024

I also think a pure KAN implementation for computer vision does not look very promising due to not making any use of spatial locality. (Fredrik00's full comment, quoted from above)

Good point on the 2D conv layer. One possibility is to define a kan_conv2d layer; then we could build KAN3D, or KAN directly, with different conv2d layers.

Currently all the tests on images (such as MNIST) process the data into a 1D vector, which is not very useful.

zdx3578 commented on August 17, 2024

@xiaol what about using handwritten sequence trajectories?

Uljibuh commented on August 17, 2024

How can I build a Conv-KAN? How do I integrate convolution into KAN?

paulestano commented on August 17, 2024

I used a 'linearized version' of nn.Conv2d, using nn.Unfold and a reshape, to build a KANConv2d. I'm not completely sure whether it makes sense, and I don't think it's efficient at all, but you may check it out. The idea is sketched below.
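A minimal sketch of that linearization (my own illustration, not paulestano's code): nn.Unfold turns every k x k patch into a row, any KAN layer maps rows to output channels, and a reshape restores the spatial grid. Here kan_layer is assumed to be any module mapping (N, in_ch*k*k) to (N, out_ch), e.g. efficient-kan's KANLinear.

import torch
import torch.nn as nn

class LinearizedKANConv2d(nn.Module):
    def __init__(self, kan_layer, kernel_size=3, stride=1, padding=1):
        super().__init__()
        self.kan_layer = kan_layer
        self.unfold = nn.Unfold(kernel_size, stride=stride, padding=padding)
        self.kernel_size, self.stride, self.padding = kernel_size, stride, padding

    def forward(self, x):
        b, _, h, w = x.shape
        patches = self.unfold(x)                       # (b, in_ch*k*k, L)
        _, feat, L = patches.shape
        rows = patches.transpose(1, 2).reshape(b * L, feat)
        out = self.kan_layer(rows)                     # (b*L, out_ch)
        h_out = (h + 2 * self.padding - self.kernel_size) // self.stride + 1
        w_out = (w + 2 * self.padding - self.kernel_size) // self.stride + 1
        return out.view(b, L, -1).transpose(1, 2).reshape(b, -1, h_out, w_out)

# usage sketch: conv = LinearizedKANConv2d(KANLinear(1 * 3 * 3, 8))  # hypothetical KANLinear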

SimoSbara commented on August 17, 2024

I also tried a simple implementation of LeNet, but with a KAN as the classifier: https://github.com/SimoSbara/kan-lenet

The KAN receives the flattened data from the convolutions; a sketch of this kind of architecture follows.
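A minimal sketch of such a LeNet-with-KAN-classifier (my own illustration, not the kan-lenet repo's code; KANLinear is assumed to come from efficient-kan):

import torch
import torch.nn as nn
from efficient_kan import KANLinear  # assumed import path

class LeNetKAN(nn.Module):
    def __init__(self, num_classes=10):
        super().__init__()
        self.features = nn.Sequential(
            nn.Conv2d(1, 6, 5, padding=2), nn.ReLU(), nn.AvgPool2d(2),
            nn.Conv2d(6, 16, 5), nn.ReLU(), nn.AvgPool2d(2),
        )
        # 28x28 input -> 16 channels of 5x5 after the conv stack
        self.classifier = KANLinear(16 * 5 * 5, num_classes)

    def forward(self, x):
        x = self.features(x)
        return self.classifier(torch.flatten(x, 1))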

SimoSbara commented on August 17, 2024

I think a combination of them (nn.Linear + KAN) works fine for the MNIST task: (HaiFengZeng's full comment, code, and logs, quoted from above)

In comparison with the MLP it's a good improvement, although in real cases it's the convolutions that give real robustness in OCR applications.

It would be nice to have a performance benchmark for bigger nets where KAN replaces the MLP.

hesamsheikh commented on August 17, 2024

As a bit of an experiment, I tried training KAN on MNIST:

import torch
import matplotlib.pyplot as plt
from kan import KAN

def create_kan():
    return KAN(width=[7**2, 3, 10], grid=3, k=3)
model = create_kan()

def test_acc():
    with torch.no_grad():
        predictions = torch.argmax(model(dataset["test_input"]), dim=1)
        correct = (predictions == dataset["test_label"]).float()
        accuracy = correct.mean()
    return accuracy

def train_acc():
    with torch.no_grad():
        predictions = torch.argmax(model(dataset["train_input"].to('cpu')), dim=1)
        correct = (predictions == dataset["train_label"].to('cpu')).float()
        accuracy = correct.mean()
    return accuracy

# Train the model
results = model.train(
    dataset,
    opt="LBFGS",
    steps=20,
    batch=512,
    loss_fn=torch.nn.CrossEntropyLoss(),
    metrics=(train_acc, test_acc)
)
torch.save(model.state_dict(), "kan.pth")

del model
model = create_kan()
model.load_state_dict(torch.load("kan.pth"))

acc = test_acc()
print(f"Test accuracy: {acc.item() * 100:.2f}%")

plt.plot(results["train_loss"], label="train")
plt.plot(results["test_loss"], label="test")
plt.legend()

I get 81% accuracy with a KAN of 10,640 parameters.

(loss-curve image omitted)

Doing the same experiment, I'm getting 91% accuracy with a fully connected network of 15,306 parameters (49·128+128 + 128·64+64 + 64·10+10 = 15,306):

import torch.nn as nn
import torch.optim as optim

class FullyConnectedNN(nn.Module):
    def __init__(self):
        super(FullyConnectedNN, self).__init__()
        self.fc1 = nn.Linear(7*7, 128)  # 7*7 is the size of the resized and flattened image
        self.fc2 = nn.Linear(128, 64)
        self.fc3 = nn.Linear(64, 10)  # Output 10 classes for MNIST digits

    def forward(self, x):
        x = torch.relu(self.fc1(x))
        x = torch.relu(self.fc2(x))
        x = self.fc3(x)
        return x

def train_and_evaluate(model, train_data, train_labels, test_data, test_labels, epochs=20, batch_size=512):
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters())
    for epoch in range(epochs):
        for i in range(0, len(train_data), batch_size):
            inputs = train_data[i:i+batch_size]
            labels = train_labels[i:i+batch_size]
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
        print(f'Epoch {epoch+1}, Loss: {loss.item()}')
        test_acc = evaluate_accuracy(model, test_data, test_labels)
        print(f'Test Accuracy: {test_acc}')
    
def evaluate_accuracy(model, data, labels):
    with torch.no_grad():
        outputs = model(data)
        _, predicted = torch.max(outputs, 1)
        correct = (predicted == labels).float()
        accuracy = correct.mean()
    return accuracy

# Create and train the model
fcnn_model = FullyConnectedNN()
print(sum(p.numel() for p in fcnn_model.parameters()))
train_and_evaluate(fcnn_model, train_tensor, train_labels_tensor, test_tensor, test_labels_tensor)

This is somewhat far off from the scaling benefits of KAN over MLP reported in the paper's experiments. Now, flattening an image for a vision task is not best practice, but it's an equal setting for KAN and MLP. So what is your take?

GeorgeDeac commented on August 17, 2024

As a bit of an experiment, I tried training KAN on MNIST: (hesamsheikh's full comment and code, quoted from above)

Maybe the distribution of the data in the flattened vector is harder to represent with splines than with the universal approximation of the perceptron. I would imagine that flattening an image to a single vector produces very sudden, local differences across instances, i.e. a finer granularity that might be inherently harder to represent with splines?

Edit:
Here is the distribution of the input data we are actually trying to learn from that flattened vector:

(distribution plot omitted)

which corresponds to this heatmap in the non-flattened image:

(heatmap omitted)

And these are the ranges of pixel intensities:

(intensity-range plot omitted)

So yeah, there are a lot of sudden jumps. The plots can be reproduced roughly as sketched below.

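For anyone who wants to reproduce plots like these, a rough sketch (X_train is assumed to be the 7x7-resized MNIST inputs used earlier in this thread):

import matplotlib.pyplot as plt

flat = X_train.reshape(len(X_train), -1)
plt.figure(); plt.plot(flat.mean(axis=0)); plt.title("mean intensity per flattened position")
plt.figure(); plt.imshow(flat.mean(axis=0).reshape(7, 7)); plt.title("same means as a 7x7 heatmap")
plt.figure(); plt.hist(flat.ravel(), bins=50); plt.title("pixel intensity distribution")
plt.show()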

hesamsheikh commented on August 17, 2024

(hesamsheikh's experiment and GeorgeDeac's reply, quoted in full from above)

I was expecting much more resistance to sudden local jumps from splines; that is what I would infer from the continual-learning section of the paper, since splines preserve local information much better than MLPs.

(figure from the paper omitted)

I guess @KindXiaoming would have some idea about this.

GeorgeDeac commented on August 17, 2024

(the full exchange above, quoted again)

I would just guess that there might be a representation-power limit set by how small and sudden the inflections are, given the number of parameters we have for the splines? I would also like to investigate the actual reason, tbh.

hesamsheikh commented on August 17, 2024

We implemented KAN Convolutional Layers; check out our repo, based on the efficient-kan implementation: https://github.com/AntonioTepsich/Convolutional-KANs

Your results also point out that, in the case of MNIST, KAN isn't able to scale as much as promised in the paper, essentially being on the same level as an MLP for a given parameter count.
