cnn-kan's Introduction

CNN-KAN

A single-epoch CNN+KAN trial on MNIST reaching 96% accuracy. You can try it with the Notebook or via the Colab link.

Remarks

  • Loss is very unstable.

Maybe it's related to the parameter settings, but whenever I train for a few epochs (3+), the weights keep exploding.

Train Epoch: 0 [0000/6000 (00%)]	Loss: 1.662824
Train Epoch: 0 [0640/6000 (11%)]	Loss: 0.500479
Train Epoch: 0 [1280/6000 (21%)]	Loss: 0.287909
Train Epoch: 0 [1920/6000 (32%)]	Loss: 0.221105
Train Epoch: 0 [2560/6000 (43%)]	Loss: 0.107704
Train Epoch: 0 [3200/6000 (53%)]	Loss: 0.146229
Train Epoch: 0 [3840/6000 (64%)]	Loss: 0.077108
Train Epoch: 0 [4480/6000 (74%)]	Loss: 0.157832
Train Epoch: 0 [5120/6000 (85%)]	Loss: 0.075530
Train Epoch: 0 [5760/6000 (96%)]	Loss: 0.064992

Test set: Average loss: 0.0007, Accuracy: 9620/10000 (96%)
  • Other optimization algorithms I tried did not work.
  • SELU instead of the usual ReLU.
  • Two KAN layers were more stable than a single one.
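A common mitigation for the exploding weights mentioned above is gradient clipping before each optimizer step. A minimal sketch, not part of this repo (the model, optimizer, and threshold are placeholders):

```python
import torch
import torch.nn as nn

# Toy stand-in for the CNN+KAN model (hypothetical).
model = nn.Sequential(nn.Flatten(), nn.Linear(28 * 28, 10))
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
loss_fn = nn.CrossEntropyLoss()

data = torch.randn(8, 1, 28, 28)        # fake MNIST-shaped batch
target = torch.randint(0, 10, (8,))

optimizer.zero_grad()
loss = loss_fn(model(data), target)
loss.backward()
# Rescale gradients so their global norm is at most 1.0 before stepping.
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
optimizer.step()
```

The `max_norm` value is a tunable hyperparameter; values around 0.5 to 5.0 are typical starting points.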

Acknowledgement

Kolmogorov-Arnold Networks (original work): pyKAN

Fourier coefficient instead of splines: FourierKAN
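The Fourier variant replaces each spline-based edge function with a truncated Fourier series, phi(x) = sum_k a_k cos(kx) + b_k sin(kx). A simplified sketch of that idea (not the actual FourierKAN implementation; layer name and initialization are illustrative):

```python
import torch
import torch.nn as nn

class NaiveFourierKANLayer(nn.Module):
    """Each input-output edge learns phi(x) = sum_k a_k cos(kx) + b_k sin(kx)."""
    def __init__(self, in_dim, out_dim, grid_size=5):
        super().__init__()
        self.grid_size = grid_size
        # Fourier coefficients: (2, out_dim, in_dim, grid_size) for cos/sin terms.
        self.coeffs = nn.Parameter(
            torch.randn(2, out_dim, in_dim, grid_size) / (in_dim * grid_size) ** 0.5
        )

    def forward(self, x):                      # x: (batch, in_dim)
        k = torch.arange(1, self.grid_size + 1, device=x.device)
        kx = x.unsqueeze(-1) * k               # (batch, in_dim, grid_size)
        cos, sin = torch.cos(kx), torch.sin(kx)
        # Sum the learned edge functions over input dims and frequencies.
        y = torch.einsum('big,oig->bo', cos, self.coeffs[0])
        y = y + torch.einsum('big,oig->bo', sin, self.coeffs[1])
        return y

layer = NaiveFourierKANLayer(16, 10)
out = layer(torch.randn(4, 16))                # shape (4, 10)
```

Unlike splines, the Fourier basis is smooth and global, which changes the optimization behavior relative to pyKAN's local spline grid.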


cnn-kan's Issues

loss?

Hello, sorry to bother you. During training, did you run into the memory filling up? (I fixed that error in the way shown below.) Also, why does the loss of this code not decrease during training?

import torch.nn as nn

loss_fn = nn.CrossEntropyLoss()

# Standard training loop: one optimizer step per batch.
def train(model, device, train_loader, optimizer, epoch):
    model.train()
    for batch_idx, (data, target) in enumerate(train_loader):
        optimizer.zero_grad()
        output = model(data.to(device))
        loss = loss_fn(output, target.to(device))
        loss.backward()
        optimizer.step()
        if batch_idx % 10 == 0:
            print(f'Train Epoch: {epoch} [{batch_idx * len(data)}/{len(train_loader.dataset)} ({100. * batch_idx / len(train_loader):.0f}%)]\tLoss: {loss.item():.6f}')

# Closure-based variant: required by optimizers such as torch.optim.LBFGS,
# which re-evaluate the loss several times per step. (This redefines the
# function above; use one or the other.)
def train(model, device, train_loader, optimizer, epoch):
    model.train()
    for batch_idx, (data, target) in enumerate(train_loader):
        def closure():
            optimizer.zero_grad()
            output = model(data.to(device))
            loss = loss_fn(output, target.to(device))
            loss.backward()
            return loss
        optimizer.step(closure)
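The closure form is what torch.optim.LBFGS expects, since it may re-evaluate the loss and gradients several times per step. A minimal usage sketch on fake data (the linear model and hyperparameters are placeholders, not from this repo):

```python
import torch
import torch.nn as nn

model = nn.Linear(4, 2)                        # placeholder model
optimizer = torch.optim.LBFGS(
    model.parameters(), lr=0.1, max_iter=5, line_search_fn='strong_wolfe'
)
loss_fn = nn.CrossEntropyLoss()
data, target = torch.randn(16, 4), torch.randint(0, 2, (16,))

def closure():
    # LBFGS calls this repeatedly; it must recompute loss and gradients.
    optimizer.zero_grad()
    loss = loss_fn(model(data), target)
    loss.backward()
    return loss

before = closure().item()
optimizer.step(closure)                         # LBFGS drives the closure internally
after = loss_fn(model(data), target).item()
```

With the strong-Wolfe line search, each step is guaranteed not to increase the loss on this convex logistic-regression problem.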
