
Da1sypetals/cuda-Wavelet-KAN


CUDA implementation of Wavelet KAN

  • See also other CUDA implementations of KAN:

  • I am interested in the performance aspects of KAN and happy to discuss or receive more information on this topic :)

  • This project is for personal practice; use it at your own risk. Tested on my RTX3050 and on a remote RTX3090 with CUDA 12.x.

Update:

  • A much faster implementation has been added, but please note:
    • Since the implementation uses tiling, assertions are enabled by default to ensure that tensors conform to the tiling restrictions.
    • If NaNs appear during training, first check that all dimensions are divisible by 128. If that does not solve the problem, feel free to open an issue (though there is a good chance I cannot solve it either).
    • Currently only the forward pass is optimized; optimizing the backward pass is mathematically similar and may be done after my exams.
    • I am a CUDA beginner and am grateful for any optimization suggestions :)
  • Thanks to https://github.com/siboehm/SGEMM_CUDA for the optimized SGEMM code, which I adapted with some modifications for this implementation.
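Because the tiled forward kernel asserts on tensor shapes, it can help to check dimensions before calling it. A minimal sketch in plain Python, assuming (per the 128-divisibility note above) that every dimension must be a positive multiple of a 128-element tile; the helper name `check_tiled_shape` and its `tile` parameter are hypothetical and not part of this repository:

```python
def check_tiled_shape(*dims, tile=128):
    """Return the number of tiles along each dimension.

    Hypothetical helper: raises ValueError if any dimension is not a positive
    multiple of `tile`, mirroring the kind of assertion a tiled kernel needs.
    """
    bad = [d for d in dims if d <= 0 or d % tile != 0]
    if bad:
        raise ValueError(f"dimensions {bad} are not positive multiples of {tile}")
    # Every dimension splits into d // tile full tiles, with no ragged edge tile.
    return tuple(d // tile for d in dims)


# Example: batch 256, in_features 512, out_features 1024.
print(check_tiled_shape(256, 512, 1024))  # (2, 4, 8)
```

A shape such as (100, 512) fails this check, while sizes like 128, 256 or 512 pass.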

Results on RTX3050:

          |  forward time | backward time |  forward mem  | backward mem  |   num params  |  num trainable params
---------------------------------------------------------------------------------------------------------------------------
cuda-gpu  |    117.24 ms  |    651.21 ms  |      1.10 GB  |      4.12 GB  |     12787840  |              12787840
gemm-gpu  |     26.21 ms  |    678.70 ms  |      0.10 GB  |      4.12 GB  |     12787840  |              12787840
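The gain from the tiled version can be read off the RTX3050 table by dividing the cuda-gpu row by the gemm-gpu row. A small sketch of that arithmetic; the dictionary keys are hypothetical labels for the table's columns:

```python
def ratios(baseline, candidate):
    """Per-column baseline / candidate ratios; > 1 means the candidate is faster or smaller."""
    return {key: baseline[key] / candidate[key] for key in baseline}


# Values copied from the RTX3050 table (times in ms, memory in GB).
cuda_gpu = {"fwd_ms": 117.24, "bwd_ms": 651.21, "fwd_gb": 1.10, "bwd_gb": 4.12}
gemm_gpu = {"fwd_ms": 26.21, "bwd_ms": 678.70, "fwd_gb": 0.10, "bwd_gb": 4.12}

for key, r in ratios(cuda_gpu, gemm_gpu).items():
    print(f"{key}: {r:.2f}x")
```

Forward time drops by about 4.5x and forward memory by about 11x, while backward time and memory stay essentially unchanged (ratios near 1), consistent with only the forward pass being optimized so far.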

Introduction

CUDA implementation of the paper introducing Wavelet KAN at https://arxiv.org/abs/2405.12832.

This CUDA implementation is significantly faster than the original implementation. In the benchmark below, produced with the scripts in https://github.com/Jerry-Master/KAN-benchmarking, it is about 18x faster forward and 22x faster backward, and uses about 20x less memory forward and 5x less backward.

          |  forward time | backward time |  forward mem  | backward mem  |   num params  |  num trainable params
---------------------------------------------------------------------------------------------------------------------------
cuda-gpu  |     29.10 ms  |     65.27 ms  |      0.28 GB  |      1.03 GB  |      3151362  |               3151362
orig-gpu  |    522.00 ms  |   1461.29 ms  |      5.53 GB  |      5.53 GB  |      3151362  |               3151362
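For reference, the computation that these kernels accelerate can be written as a naive loop. A minimal pure-Python sketch, assuming the Wav-KAN formulation in which each input-output edge (i, j) has its own scale a[j][i], translation b[j][i] and weight w[j][i], so that output j is the sum over inputs i of w[j][i] * psi((x[i] - b[j][i]) / a[j][i]). Here psi is the textbook Mexican hat, the function names are hypothetical, and the repository's kernels may use different constants or add further terms:

```python
import math


def mexican_hat(t):
    # Textbook Mexican hat (Ricker) wavelet: C * (1 - t^2) * exp(-t^2 / 2).
    c = 2.0 / (math.sqrt(3.0) * math.pi ** 0.25)
    return c * (1.0 - t * t) * math.exp(-0.5 * t * t)


def wavelet_kan_forward(x, scale, translation, weight, psi=mexican_hat):
    """Naive forward pass of one (hypothetical) Wavelet KAN layer for one sample.

    x:                           list of in_features floats
    scale, translation, weight:  out_features x in_features nested lists,
                                 one value per input-output edge
    Returns a list of out_features floats.
    """
    out = []
    for a_row, b_row, w_row in zip(scale, translation, weight):
        # Sum the weighted, shifted and scaled wavelet over all inputs.
        y = sum(w * psi((xi - b) / a) for xi, a, b, w in zip(x, a_row, b_row, w_row))
        out.append(y)
    return out


# Two inputs, one output; unit scales and zero translations give psi(0) per edge.
y = wavelet_kan_forward([0.0, 0.0], [[1.0, 1.0]], [[0.0, 0.0]], [[1.0, 2.0]])
print(y)
```

For a batch this is a batch x out_features x in_features loop nest with the same access pattern as a matrix multiply, which is presumably why SGEMM-style tiling carries over.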

Note

  • There are no optimizations in the original implementation. I am a CUDA beginner and welcome optimization suggestions :)

  • Currently the Mexican hat and Morlet wavelets are implemented.
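The two wavelets, together with the derivatives a backward pass needs, can be sketched as follows. This assumes the textbook forms: a Mexican hat normalized by 2 / (sqrt(3) * pi^(1/4)) and a real Morlet wavelet cos(omega0 * t) * exp(-t^2 / 2) with omega0 = 5. The constants and sign conventions in this repository's kernels may differ:

```python
import math

C_MEXHAT = 2.0 / (math.sqrt(3.0) * math.pi ** 0.25)
OMEGA0 = 5.0  # assumed Morlet central frequency


def mexican_hat(t):
    return C_MEXHAT * (1.0 - t * t) * math.exp(-0.5 * t * t)


def mexican_hat_grad(t):
    # d/dt [C (1 - t^2) e^{-t^2/2}] = C * t * (t^2 - 3) * e^{-t^2/2}
    return C_MEXHAT * t * (t * t - 3.0) * math.exp(-0.5 * t * t)


def morlet(t):
    return math.cos(OMEGA0 * t) * math.exp(-0.5 * t * t)


def morlet_grad(t):
    # Product rule: -(omega0 * sin(omega0 t) + t * cos(omega0 t)) * e^{-t^2/2}
    return -(OMEGA0 * math.sin(OMEGA0 * t) + t * math.cos(OMEGA0 * t)) * math.exp(-0.5 * t * t)


# Check the analytic derivatives against central finite differences.
h = 1e-6
for f, df in [(mexican_hat, mexican_hat_grad), (morlet, morlet_grad)]:
    for t in (-1.3, 0.0, 0.7, 2.1):
        numeric = (f(t + h) - f(t - h)) / (2 * h)
        assert abs(numeric - df(t)) < 1e-6
print("derivatives match")
```

With t = (x - b) / a, the chain rule gives dt/dx = 1/a, dt/db = -1/a and dt/da = -t/a, so the gradients for the input and for both per-edge wavelet parameters all reuse psi'(t), while the gradient for the weight is psi(t) itself.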

Start

  1. Install
pip install -e .

Make sure the nvcc on your PATH matches the CUDA version your PyTorch build was compiled against (a minor-version difference seems to be OK).

  2. Run

    • Run the test on MNIST:
    python test.py
  3. Benchmark

python benchmark.py --method all --reps 10 --just-cuda
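Regarding the nvcc compatibility note in the Install step: one way to compare the two versions is to parse the release number printed by `nvcc --version` and compare it with `torch.version.cuda`, the CUDA version string your PyTorch build reports. A minimal sketch, assuming only the major version must match (in line with the observation that minor differences seem fine); the helper names are hypothetical:

```python
import re


def parse_nvcc_release(output):
    """Extract (major, minor) from `nvcc --version` output, e.g. '... release 12.1, ...'."""
    m = re.search(r"release (\d+)\.(\d+)", output)
    if m is None:
        raise ValueError("no CUDA release number found in nvcc output")
    return int(m.group(1)), int(m.group(2))


def versions_compatible(nvcc_output, torch_cuda):
    """True if nvcc and PyTorch's CUDA version (a string such as '12.1') share a major version."""
    nvcc_major, _ = parse_nvcc_release(nvcc_output)
    torch_major = int(torch_cuda.split(".")[0])
    return nvcc_major == torch_major


sample = "Cuda compilation tools, release 12.1, V12.1.105"
print(versions_compatible(sample, "12.4"))  # True
print(versions_compatible(sample, "11.8"))  # False
```

In practice the first argument would come from `subprocess.run(["nvcc", "--version"], capture_output=True, text=True).stdout` and the second from `torch.version.cuda`, which is `None` on CPU-only PyTorch builds.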

Please note:

  1. The Morlet wavelet performs poorly on MNIST, but with a shallow network you can observe it learning.
