整个五一假期都没有碰代码(除了第一天捣鼓了一点小玩意)。到了假期末的时候刷了一下Github的热榜,看到一个国人的开源仓库冲上了日榜,也就是今天要学习的内容,KANs。
作者是斯坦福大学的博士生刘子鸣,他将自己开发的网络命名为Kolmogorov-Arnold Networks。在他的博客上介绍了自己,毕业于北大物理系,并曾在微软亚院实习。目前他主要研究的方向是传统物理和AI的交叉学科,看起来很笼统。看了一下Google Scholar,引用最高的几篇文章偏向机器学习方面。
该文的命名来源于Kolmogorov-Arnold表示定理,此前我完全没有了解过。看了很多网上的解释,我理解为任何多变量连续 函数都可以表示为多个单变量、加法连续函数的有限组合。
公式如下:$ f(\mathbf{x}) = f(x_1, \dots, x_n) = \sum_{q=0}^{2n} \Phi_q \left( \sum_{p=1}^n \phi_{qp}(x_p) \right) $
这里等号后面的括号里和括号外就是从$[0,1]$的$R$的连续函数。括号里是内部函数,外则是外部函数。连续函数可以是线性变换函数或者二次函数等等。
这个理论在机器学习领域可以简化为:学习高维函数的过程可以简化成学习多项式数量的一维函数。KANs的想法则是替代前馈网络:$ \mathbf{y} = \sigma(\mathbf{Wx} + \mathbf{b}) $
为什么之前在机器学习中没有被人们所使用?论文中给出了自己的解释:
有人可能天真地认为这对机器学习来说是个好消息:学习高维函数归结为学习多项式数量的一维函数。然而,这些一维函数可能是非光滑的,甚至是分形的,因此在实践中可能无法学习。由于这种病态行为,科尔莫戈洛夫-阿诺德表示定理在机器学习中基本上被判了死刑,被认为在理论上是正确的,但在实践中是无用的。
拿Github上的原图来展示一下:
在前馈网络中,最终需要拟合的函数由多个线性函数($W$权重)以及非线性函数(激活函数)组合而成。而在KANs则变为了$KAN(x)=(\Phi_3\circ \Phi_2 \circ \Phi_1)(x)$。
上图同时还展示了另外一点,那就是KANs网络的可学习参数要比MLP要少了很多。其中$\Phi_2$用来实现非线性函数。
代码实现 作者的工程能力很强,提供了基于Pytorch写的框架,不过据他所说,目前代码还有一些不足。而目前Github社区也有人迅速跟进。目前这个名为efficent-kan
的项目已经获得了超过两千颗星,链接:Blealtan/efficient-kan: An efficient pure-PyTorch implementation of Kolmogorov-Arnold Network (KAN) 。
看了一眼源码,代码量不多,但是需要比较深的数学背景,所以暂时跳过这个部分。
MNIST数据集 项目用MNIST数据集进行了测试训练,代码如下:
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 from efficient_kan import KANimport torchimport torch.nn as nnimport torch.optim as optimimport torchvisionimport torchvision.transforms as transformsfrom torch.utils.data import DataLoaderfrom tqdm import tqdm transform = transforms.Compose( [transforms.ToTensor(), transforms.Normalize((0.5 ,), (0.5 ,))] ) trainset = torchvision.datasets.MNIST( root="./data" , train=True , download=True , transform=transform ) valset = torchvision.datasets.MNIST( root="./data" , train=False , download=True , transform=transform ) trainloader = DataLoader(trainset, batch_size=64 , shuffle=True ) valloader = DataLoader(valset, batch_size=64 , shuffle=False ) model = KAN([28 * 28 , 64 , 10 ]) device = torch.device("cuda" if torch.cuda.is_available() else "cpu" ) model.to(device) optimizer = optim.AdamW(model.parameters(), lr=1e-3 , weight_decay=1e-4 ) scheduler = optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.8 ) criterion = nn.CrossEntropyLoss()for epoch in range (10 ): model.train() with tqdm(trainloader) as pbar: for i, (images, labels) in enumerate (pbar): images = images.view(-1 , 28 * 28 ).to(device) optimizer.zero_grad() output = model(images) loss = criterion(output, labels.to(device)) loss.backward() optimizer.step() accuracy = (output.argmax(dim=1 ) == labels.to(device)).float ().mean() pbar.set_postfix(loss=loss.item(), accuracy=accuracy.item(), lr=optimizer.param_groups[0 ]['lr' ]) model.eval () val_loss = 0 val_accuracy = 0 with torch.no_grad(): for images, labels in valloader: images = images.view(-1 , 28 * 28 ).to(device) output = model(images) val_loss += criterion(output, labels.to(device)).item() val_accuracy += ( (output.argmax(dim=1 ) == labels.to(device)).float ().mean().item() ) val_loss /= len (valloader) val_accuracy /= len (valloader) scheduler.step() print ( f"Epoch {epoch + 1 } , Val Loss: {val_loss} , Val Accuracy: {val_accuracy} " )
从结果上来看,KANs在收敛速度上比传统的MLP要快,精度相差不大。不过MLP存在过拟合的问题。数据集过小。还需要在更多的场景验证。
总的来说,KANs架构带来最大的两个优势:
计算复杂度降低带来的收敛速度提高。
动态图结构(区别于MLP的静态图结构中固定的激活函数)更加灵活,大家普遍认为这能解决灾难性遗忘的问题,因为训练时,较远的权重参数之间不会有太大的影响。
2024/5/11 于苏州