diff --git a/models/others/kolmogorov_arnold_networks/kan/pytorch/README.md b/models/others/kolmogorov_arnold_networks/kan/pytorch/README.md
index 4a63b5d4978e40d18240f8dd5549d6f1bb4a2a6f..db3966a57a5874a821f9d7b803da6d386e8e81d6 100644
--- a/models/others/kolmogorov_arnold_networks/kan/pytorch/README.md
+++ b/models/others/kolmogorov_arnold_networks/kan/pytorch/README.md
@@ -12,6 +12,7 @@ both model accuracy and interpretability.
 
 | GPU | [IXUCA SDK](https://gitee.com/deep-spark/deepspark#%E5%A4%A9%E6%95%B0%E6%99%BA%E7%AE%97%E8%BD%AF%E4%BB%B6%E6%A0%88-ixuca) | Release |
 | :----: | :----: | :----: |
+| BI-V150 | 4.3.0 | 25.12 |
 | BI-V150 | 4.1.1 | 24.12 |
 
 ## Model Preparation
@@ -20,12 +21,39 @@
 
 ```bash
 pip3 install -r requirements.txt
+git clone https://github.com/KindXiaoming/pykan.git
+cd pykan
+```
+
+In `kan/spline.py`, replace the following code at line 117, which can fail when `torch.linalg.lstsq` runs on the GPU:
+
+```python
+try:
+    coef = torch.linalg.lstsq(mat, y_eval).solution[:,:,:,0]
+except:
+    print('lstsq failed')
+```
+
+with the version below, which computes the least-squares solution on the CPU and moves the result back to the device:
+
+```python
+try:
+    coef = torch.linalg.lstsq(mat.cpu(), y_eval.cpu()).solution[:,:,:,0]
+    coef = coef.to(device)
+except:
+    print('lstsq failed')
+```
+
+Then install the patched package in editable mode:
+
+```bash
+pip3 install -e .
 ```
 
 ## Model Training
 
 ```bash
-bash ./run_train.sh
+python3 ./train_kan.py --steps 100
 ```
 
 ## Model Results
diff --git a/models/others/kolmogorov_arnold_networks/kan/pytorch/__init__.py b/models/others/kolmogorov_arnold_networks/kan/pytorch/__init__.py
deleted file mode 100644
index e69de29bb2d1d6434b8b29ae775ad8c2e48c5391..0000000000000000000000000000000000000000
diff --git a/models/others/kolmogorov_arnold_networks/kan/pytorch/kan/.ipynb_checkpoints/KANLayer-checkpoint.py b/models/others/kolmogorov_arnold_networks/kan/pytorch/kan/.ipynb_checkpoints/KANLayer-checkpoint.py
deleted file mode 100644
index b880bfe8b58b69edd1b7d5dfe7234ec27d6af8ec..0000000000000000000000000000000000000000
--- a/models/others/kolmogorov_arnold_networks/kan/pytorch/kan/.ipynb_checkpoints/KANLayer-checkpoint.py
+++ /dev/null
@@ -1,364 +0,0 @@
-import torch
-import torch.nn as nn
-import numpy as np
-from .spline import *
-from .utils import sparse_mask
-
-
-class KANLayer(nn.Module):
-    """
-    KANLayer class
-
-
-    Attributes:
-    -----------
-        in_dim: int
-            input dimension
-        out_dim: int
-            output dimension
-        num: int
-            the number of grid intervals
-        k: int
-            the piecewise polynomial order of splines
-        noise_scale: float
-            spline scale at initialization
-        coef: 2D torch.tensor
-            coefficients of B-spline bases
-        scale_base_mu: float
-            magnitude of the residual function b(x) is drawn from N(mu, sigma^2), mu = sigma_base_mu
-        scale_base_sigma: float
-            magnitude of the residual function b(x) is drawn from N(mu, sigma^2), mu = sigma_base_sigma
-        scale_sp: float
-            mangitude of the spline function spline(x)
-        base_fun: fun
-            residual function b(x)
-        mask: 1D torch.float
-            mask of spline functions. setting some element of the mask to zero means setting the corresponding activation to zero function.
-        grid_eps: float in [0,1]
-            a hyperparameter used in update_grid_from_samples. When grid_eps = 1, the grid is uniform; when grid_eps = 0, the grid is partitioned using percentiles of samples. 0 < grid_eps < 1 interpolates between the two extremes.
- the id of activation functions that are locked - device: str - device - """ - - def __init__(self, in_dim=3, out_dim=2, num=5, k=3, noise_scale=0.5, scale_base_mu=0.0, scale_base_sigma=1.0, scale_sp=1.0, base_fun=torch.nn.SiLU(), grid_eps=0.02, grid_range=[-1, 1], sp_trainable=True, sb_trainable=True, save_plot_data = True, device='cpu', sparse_init=False): - '''' - initialize a KANLayer - - Args: - ----- - in_dim : int - input dimension. Default: 2. - out_dim : int - output dimension. Default: 3. - num : int - the number of grid intervals = G. Default: 5. - k : int - the order of piecewise polynomial. Default: 3. - noise_scale : float - the scale of noise injected at initialization. Default: 0.1. - scale_base_mu : float - the scale of the residual function b(x) is intialized to be N(scale_base_mu, scale_base_sigma^2). - scale_base_sigma : float - the scale of the residual function b(x) is intialized to be N(scale_base_mu, scale_base_sigma^2). - scale_sp : float - the scale of the base function spline(x). - base_fun : function - residual function b(x). Default: torch.nn.SiLU() - grid_eps : float - When grid_eps = 1, the grid is uniform; when grid_eps = 0, the grid is partitioned using percentiles of samples. 0 < grid_eps < 1 interpolates between the two extremes. - grid_range : list/np.array of shape (2,) - setting the range of grids. Default: [-1,1]. - sp_trainable : bool - If true, scale_sp is trainable - sb_trainable : bool - If true, scale_base is trainable - device : str - device - sparse_init : bool - if sparse_init = True, sparse initialization is applied. - - Returns: - -------- - self - - Example - ------- - >>> from kan.KANLayer import * - >>> model = KANLayer(in_dim=3, out_dim=5) - >>> (model.in_dim, model.out_dim) - ''' - super(KANLayer, self).__init__() - # size - self.out_dim = out_dim - self.in_dim = in_dim - self.num = num - self.k = k - - grid = torch.linspace(grid_range[0], grid_range[1], steps=num + 1)[None,:].expand(self.in_dim, num+1) - grid = extend_grid(grid, k_extend=k) - self.grid = torch.nn.Parameter(grid).requires_grad_(False) - noises = (torch.rand(self.num+1, self.in_dim, self.out_dim) - 1/2) * noise_scale / num - - self.coef = torch.nn.Parameter(curve2coef(self.grid[:,k:-k].permute(1,0), noises, self.grid, k)) - - if sparse_init: - self.mask = torch.nn.Parameter(sparse_mask(in_dim, out_dim)).requires_grad_(False) - else: - self.mask = torch.nn.Parameter(torch.ones(in_dim, out_dim)).requires_grad_(False) - - self.scale_base = torch.nn.Parameter(scale_base_mu * 1 / np.sqrt(in_dim) + \ - scale_base_sigma * (torch.rand(in_dim, out_dim)*2-1) * 1/np.sqrt(in_dim)).requires_grad_(sb_trainable) - self.scale_sp = torch.nn.Parameter(torch.ones(in_dim, out_dim) * scale_sp * 1 / np.sqrt(in_dim) * self.mask).requires_grad_(sp_trainable) # make scale trainable - self.base_fun = base_fun - - - self.grid_eps = grid_eps - - self.to(device) - - def to(self, device): - super(KANLayer, self).to(device) - self.device = device - return self - - def forward(self, x): - ''' - KANLayer forward given input x - - Args: - ----- - x : 2D torch.float - inputs, shape (number of samples, input dimension) - - Returns: - -------- - y : 2D torch.float - outputs, shape (number of samples, output dimension) - preacts : 3D torch.float - fan out x into activations, shape (number of sampels, output dimension, input dimension) - postacts : 3D torch.float - the outputs of activation functions with preacts as inputs - postspline : 3D torch.float - the outputs of spline functions with preacts as inputs 
- - Example - ------- - >>> from kan.KANLayer import * - >>> model = KANLayer(in_dim=3, out_dim=5) - >>> x = torch.normal(0,1,size=(100,3)) - >>> y, preacts, postacts, postspline = model(x) - >>> y.shape, preacts.shape, postacts.shape, postspline.shape - ''' - batch = x.shape[0] - preacts = x[:,None,:].clone().expand(batch, self.out_dim, self.in_dim) - - base = self.base_fun(x) # (batch, in_dim) - y = coef2curve(x_eval=x, grid=self.grid, coef=self.coef, k=self.k) - - postspline = y.clone().permute(0,2,1) - - y = self.scale_base[None,:,:] * base[:,:,None] + self.scale_sp[None,:,:] * y - y = self.mask[None,:,:] * y - - postacts = y.clone().permute(0,2,1) - - y = torch.sum(y, dim=1) - return y, preacts, postacts, postspline - - def update_grid_from_samples(self, x, mode='sample'): - ''' - update grid from samples - - Args: - ----- - x : 2D torch.float - inputs, shape (number of samples, input dimension) - - Returns: - -------- - None - - Example - ------- - >>> model = KANLayer(in_dim=1, out_dim=1, num=5, k=3) - >>> print(model.grid.data) - >>> x = torch.linspace(-3,3,steps=100)[:,None] - >>> model.update_grid_from_samples(x) - >>> print(model.grid.data) - ''' - - batch = x.shape[0] - #x = torch.einsum('ij,k->ikj', x, torch.ones(self.out_dim, ).to(self.device)).reshape(batch, self.size).permute(1, 0) - x_pos = torch.sort(x, dim=0)[0] - y_eval = coef2curve(x_pos, self.grid, self.coef, self.k) - num_interval = self.grid.shape[1] - 1 - 2*self.k - - def get_grid(num_interval): - ids = [int(batch / num_interval * i) for i in range(num_interval)] + [-1] - grid_adaptive = x_pos[ids, :].permute(1,0) - margin = 0.00 - h = (grid_adaptive[:,[-1]] - grid_adaptive[:,[0]] + 2 * margin)/num_interval - grid_uniform = grid_adaptive[:,[0]] - margin + h * torch.arange(num_interval+1,)[None, :].to(x.device) - grid = self.grid_eps * grid_uniform + (1 - self.grid_eps) * grid_adaptive - return grid - - - grid = get_grid(num_interval) - - if mode == 'grid': - sample_grid = get_grid(2*num_interval) - x_pos = sample_grid.permute(1,0) - y_eval = coef2curve(x_pos, self.grid, self.coef, self.k) - - self.grid.data = extend_grid(grid, k_extend=self.k) - #print('x_pos 2', x_pos.shape) - #print('y_eval 2', y_eval.shape) - self.coef.data = curve2coef(x_pos, y_eval, self.grid, self.k) - - def initialize_grid_from_parent(self, parent, x, mode='sample'): - ''' - update grid from a parent KANLayer & samples - - Args: - ----- - parent : KANLayer - a parent KANLayer (whose grid is usually coarser than the current model) - x : 2D torch.float - inputs, shape (number of samples, input dimension) - - Returns: - -------- - None - - Example - ------- - >>> batch = 100 - >>> parent_model = KANLayer(in_dim=1, out_dim=1, num=5, k=3) - >>> print(parent_model.grid.data) - >>> model = KANLayer(in_dim=1, out_dim=1, num=10, k=3) - >>> x = torch.normal(0,1,size=(batch, 1)) - >>> model.initialize_grid_from_parent(parent_model, x) - >>> print(model.grid.data) - ''' - - batch = x.shape[0] - - # shrink grid - x_pos = torch.sort(x, dim=0)[0] - y_eval = coef2curve(x_pos, parent.grid, parent.coef, parent.k) - num_interval = self.grid.shape[1] - 1 - 2*self.k - - - ''' - # based on samples - def get_grid(num_interval): - ids = [int(batch / num_interval * i) for i in range(num_interval)] + [-1] - grid_adaptive = x_pos[ids, :].permute(1,0) - h = (grid_adaptive[:,[-1]] - grid_adaptive[:,[0]])/num_interval - grid_uniform = grid_adaptive[:,[0]] + h * torch.arange(num_interval+1,)[None, :].to(x.device) - grid = self.grid_eps * grid_uniform + (1 - self.grid_eps) 
* grid_adaptive - return grid''' - - #print('p', parent.grid) - # based on interpolating parent grid - def get_grid(num_interval): - x_pos = parent.grid[:,parent.k:-parent.k] - #print('x_pos', x_pos) - sp2 = KANLayer(in_dim=1, out_dim=self.in_dim,k=1,num=x_pos.shape[1]-1,scale_base_mu=0.0, scale_base_sigma=0.0).to(x.device) - - #print('sp2_grid', sp2.grid[:,sp2.k:-sp2.k].permute(1,0).expand(-1,self.in_dim)) - #print('sp2_coef_shape', sp2.coef.shape) - sp2_coef = curve2coef(sp2.grid[:,sp2.k:-sp2.k].permute(1,0).expand(-1,self.in_dim), x_pos.permute(1,0).unsqueeze(dim=2), sp2.grid[:,:], k=1).permute(1,0,2) - shp = sp2_coef.shape - #sp2_coef = torch.cat([torch.zeros(shp[0], shp[1], 1), sp2_coef, torch.zeros(shp[0], shp[1], 1)], dim=2) - #print('sp2_coef',sp2_coef) - #print(sp2.coef.shape) - sp2.coef.data = sp2_coef - percentile = torch.linspace(-1,1,self.num+1).to(self.device) - grid = sp2(percentile.unsqueeze(dim=1))[0].permute(1,0) - #print('c', grid) - return grid - - grid = get_grid(num_interval) - - if mode == 'grid': - sample_grid = get_grid(2*num_interval) - x_pos = sample_grid.permute(1,0) - y_eval = coef2curve(x_pos, parent.grid, parent.coef, parent.k) - - grid = extend_grid(grid, k_extend=self.k) - self.grid.data = grid - self.coef.data = curve2coef(x_pos, y_eval, self.grid, self.k) - - def get_subset(self, in_id, out_id): - ''' - get a smaller KANLayer from a larger KANLayer (used for pruning) - - Args: - ----- - in_id : list - id of selected input neurons - out_id : list - id of selected output neurons - - Returns: - -------- - spb : KANLayer - - Example - ------- - >>> kanlayer_large = KANLayer(in_dim=10, out_dim=10, num=5, k=3) - >>> kanlayer_small = kanlayer_large.get_subset([0,9],[1,2,3]) - >>> kanlayer_small.in_dim, kanlayer_small.out_dim - (2, 3) - ''' - spb = KANLayer(len(in_id), len(out_id), self.num, self.k, base_fun=self.base_fun) - spb.grid.data = self.grid[in_id] - spb.coef.data = self.coef[in_id][:,out_id] - spb.scale_base.data = self.scale_base[in_id][:,out_id] - spb.scale_sp.data = self.scale_sp[in_id][:,out_id] - spb.mask.data = self.mask[in_id][:,out_id] - - spb.in_dim = len(in_id) - spb.out_dim = len(out_id) - return spb - - - def swap(self, i1, i2, mode='in'): - ''' - swap the i1 neuron with the i2 neuron in input (if mode == 'in') or output (if mode == 'out') - - Args: - ----- - i1 : int - i2 : int - mode : str - mode = 'in' or 'out' - - Returns: - -------- - None - - Example - ------- - >>> from kan.KANLayer import * - >>> model = KANLayer(in_dim=2, out_dim=2, num=5, k=3) - >>> print(model.coef) - >>> model.swap(0,1,mode='in') - >>> print(model.coef) - ''' - with torch.no_grad(): - def swap_(data, i1, i2, mode='in'): - if mode == 'in': - data[i1], data[i2] = data[i2].clone(), data[i1].clone() - elif mode == 'out': - data[:,i1], data[:,i2] = data[:,i2].clone(), data[:,i1].clone() - - if mode == 'in': - swap_(self.grid.data, i1, i2, mode='in') - swap_(self.coef.data, i1, i2, mode=mode) - swap_(self.scale_base.data, i1, i2, mode=mode) - swap_(self.scale_sp.data, i1, i2, mode=mode) - swap_(self.mask.data, i1, i2, mode=mode) - diff --git a/models/others/kolmogorov_arnold_networks/kan/pytorch/kan/.ipynb_checkpoints/LBFGS-checkpoint.py b/models/others/kolmogorov_arnold_networks/kan/pytorch/kan/.ipynb_checkpoints/LBFGS-checkpoint.py deleted file mode 100644 index 212477f23ec325ad80e4f2e849db8c895b380045..0000000000000000000000000000000000000000 --- a/models/others/kolmogorov_arnold_networks/kan/pytorch/kan/.ipynb_checkpoints/LBFGS-checkpoint.py +++ /dev/null @@ -1,493 
+0,0 @@ -import torch -from functools import reduce -from torch.optim import Optimizer - -__all__ = ['LBFGS'] - -def _cubic_interpolate(x1, f1, g1, x2, f2, g2, bounds=None): - # ported from https://github.com/torch/optim/blob/master/polyinterp.lua - # Compute bounds of interpolation area - if bounds is not None: - xmin_bound, xmax_bound = bounds - else: - xmin_bound, xmax_bound = (x1, x2) if x1 <= x2 else (x2, x1) - - # Code for most common case: cubic interpolation of 2 points - # w/ function and derivative values for both - # Solution in this case (where x2 is the farthest point): - # d1 = g1 + g2 - 3*(f1-f2)/(x1-x2); - # d2 = sqrt(d1^2 - g1*g2); - # min_pos = x2 - (x2 - x1)*((g2 + d2 - d1)/(g2 - g1 + 2*d2)); - # t_new = min(max(min_pos,xmin_bound),xmax_bound); - d1 = g1 + g2 - 3 * (f1 - f2) / (x1 - x2) - d2_square = d1**2 - g1 * g2 - if d2_square >= 0: - d2 = d2_square.sqrt() - if x1 <= x2: - min_pos = x2 - (x2 - x1) * ((g2 + d2 - d1) / (g2 - g1 + 2 * d2)) - else: - min_pos = x1 - (x1 - x2) * ((g1 + d2 - d1) / (g1 - g2 + 2 * d2)) - return min(max(min_pos, xmin_bound), xmax_bound) - else: - return (xmin_bound + xmax_bound) / 2. - - -def _strong_wolfe(obj_func, - x, - t, - d, - f, - g, - gtd, - c1=1e-4, - c2=0.9, - tolerance_change=1e-9, - max_ls=25): - # ported from https://github.com/torch/optim/blob/master/lswolfe.lua - d_norm = d.abs().max() - g = g.clone(memory_format=torch.contiguous_format) - # evaluate objective and gradient using initial step - f_new, g_new = obj_func(x, t, d) - ls_func_evals = 1 - gtd_new = g_new.dot(d) - - # bracket an interval containing a point satisfying the Wolfe criteria - t_prev, f_prev, g_prev, gtd_prev = 0, f, g, gtd - done = False - ls_iter = 0 - while ls_iter < max_ls: - # check conditions - #print(f_prev, f_new, g_new) - if f_new > (f + c1 * t * gtd) or (ls_iter > 1 and f_new >= f_prev): - bracket = [t_prev, t] - bracket_f = [f_prev, f_new] - bracket_g = [g_prev, g_new.clone(memory_format=torch.contiguous_format)] - bracket_gtd = [gtd_prev, gtd_new] - break - - if abs(gtd_new) <= -c2 * gtd: - bracket = [t] - bracket_f = [f_new] - bracket_g = [g_new] - done = True - break - - if gtd_new >= 0: - bracket = [t_prev, t] - bracket_f = [f_prev, f_new] - bracket_g = [g_prev, g_new.clone(memory_format=torch.contiguous_format)] - bracket_gtd = [gtd_prev, gtd_new] - break - - # interpolate - min_step = t + 0.01 * (t - t_prev) - max_step = t * 10 - tmp = t - t = _cubic_interpolate( - t_prev, - f_prev, - gtd_prev, - t, - f_new, - gtd_new, - bounds=(min_step, max_step)) - - # next step - t_prev = tmp - f_prev = f_new - g_prev = g_new.clone(memory_format=torch.contiguous_format) - gtd_prev = gtd_new - f_new, g_new = obj_func(x, t, d) - ls_func_evals += 1 - gtd_new = g_new.dot(d) - ls_iter += 1 - - - # reached max number of iterations? - if ls_iter == max_ls: - bracket = [0, t] - bracket_f = [f, f_new] - bracket_g = [g, g_new] - - # zoom phase: we now have a point satisfying the criteria, or - # a bracket around it. 
We refine the bracket until we find the - # exact point satisfying the criteria - insuf_progress = False - # find high and low points in bracket - low_pos, high_pos = (0, 1) if bracket_f[0] <= bracket_f[-1] else (1, 0) - while not done and ls_iter < max_ls: - # line-search bracket is so small - if abs(bracket[1] - bracket[0]) * d_norm < tolerance_change: - break - - # compute new trial value - t = _cubic_interpolate(bracket[0], bracket_f[0], bracket_gtd[0], - bracket[1], bracket_f[1], bracket_gtd[1]) - - # test that we are making sufficient progress: - # in case `t` is so close to boundary, we mark that we are making - # insufficient progress, and if - # + we have made insufficient progress in the last step, or - # + `t` is at one of the boundary, - # we will move `t` to a position which is `0.1 * len(bracket)` - # away from the nearest boundary point. - eps = 0.1 * (max(bracket) - min(bracket)) - if min(max(bracket) - t, t - min(bracket)) < eps: - # interpolation close to boundary - if insuf_progress or t >= max(bracket) or t <= min(bracket): - # evaluate at 0.1 away from boundary - if abs(t - max(bracket)) < abs(t - min(bracket)): - t = max(bracket) - eps - else: - t = min(bracket) + eps - insuf_progress = False - else: - insuf_progress = True - else: - insuf_progress = False - - # Evaluate new point - f_new, g_new = obj_func(x, t, d) - ls_func_evals += 1 - gtd_new = g_new.dot(d) - ls_iter += 1 - - if f_new > (f + c1 * t * gtd) or f_new >= bracket_f[low_pos]: - # Armijo condition not satisfied or not lower than lowest point - bracket[high_pos] = t - bracket_f[high_pos] = f_new - bracket_g[high_pos] = g_new.clone(memory_format=torch.contiguous_format) - bracket_gtd[high_pos] = gtd_new - low_pos, high_pos = (0, 1) if bracket_f[0] <= bracket_f[1] else (1, 0) - else: - if abs(gtd_new) <= -c2 * gtd: - # Wolfe conditions satisfied - done = True - elif gtd_new * (bracket[high_pos] - bracket[low_pos]) >= 0: - # old low becomes new high - bracket[high_pos] = bracket[low_pos] - bracket_f[high_pos] = bracket_f[low_pos] - bracket_g[high_pos] = bracket_g[low_pos] - bracket_gtd[high_pos] = bracket_gtd[low_pos] - - # new point becomes new low - bracket[low_pos] = t - bracket_f[low_pos] = f_new - bracket_g[low_pos] = g_new.clone(memory_format=torch.contiguous_format) - bracket_gtd[low_pos] = gtd_new - - #print(bracket) - if len(bracket) == 1: - t = bracket[0] - f_new = bracket_f[0] - g_new = bracket_g[0] - else: - t = bracket[low_pos] - f_new = bracket_f[low_pos] - g_new = bracket_g[low_pos] - return f_new, g_new, t, ls_func_evals - - - -class LBFGS(Optimizer): - """Implements L-BFGS algorithm. - - Heavily inspired by `minFunc - `_. - - .. warning:: - This optimizer doesn't support per-parameter options and parameter - groups (there can be only one). - - .. warning:: - Right now all parameters have to be on a single device. This will be - improved in the future. - - .. note:: - This is a very memory intensive optimizer (it requires additional - ``param_bytes * (history_size + 1)`` bytes). If it doesn't fit in memory - try reducing the history size, or use a different algorithm. - - Args: - lr (float): learning rate (default: 1) - max_iter (int): maximal number of iterations per optimization step - (default: 20) - max_eval (int): maximal number of function evaluations per optimization - step (default: max_iter * 1.25). - tolerance_grad (float): termination tolerance on first order optimality - (default: 1e-7). 
- tolerance_change (float): termination tolerance on function - value/parameter changes (default: 1e-9). - history_size (int): update history size (default: 100). - line_search_fn (str): either 'strong_wolfe' or None (default: None). - """ - - def __init__(self, - params, - lr=1, - max_iter=20, - max_eval=None, - tolerance_grad=1e-7, - tolerance_change=1e-9, - tolerance_ys=1e-32, - history_size=100, - line_search_fn=None): - if max_eval is None: - max_eval = max_iter * 5 // 4 - defaults = dict( - lr=lr, - max_iter=max_iter, - max_eval=max_eval, - tolerance_grad=tolerance_grad, - tolerance_change=tolerance_change, - tolerance_ys=tolerance_ys, - history_size=history_size, - line_search_fn=line_search_fn) - super().__init__(params, defaults) - - if len(self.param_groups) != 1: - raise ValueError("LBFGS doesn't support per-parameter options " - "(parameter groups)") - - self._params = self.param_groups[0]['params'] - self._numel_cache = None - - def _numel(self): - if self._numel_cache is None: - self._numel_cache = reduce(lambda total, p: total + p.numel(), self._params, 0) - return self._numel_cache - - def _gather_flat_grad(self): - views = [] - for p in self._params: - if p.grad is None: - view = p.new(p.numel()).zero_() - elif p.grad.is_sparse: - view = p.grad.to_dense().view(-1) - else: - view = p.grad.view(-1) - views.append(view) - device = views[0].device - return torch.cat(views, dim=0) - - def _add_grad(self, step_size, update): - offset = 0 - for p in self._params: - numel = p.numel() - # view as to avoid deprecated pointwise semantics - p.add_(update[offset:offset + numel].view_as(p), alpha=step_size) - offset += numel - assert offset == self._numel() - - def _clone_param(self): - return [p.clone(memory_format=torch.contiguous_format) for p in self._params] - - def _set_param(self, params_data): - for p, pdata in zip(self._params, params_data): - p.copy_(pdata) - - def _directional_evaluate(self, closure, x, t, d): - self._add_grad(t, d) - loss = float(closure()) - flat_grad = self._gather_flat_grad() - self._set_param(x) - return loss, flat_grad - - - @torch.no_grad() - def step(self, closure): - """Perform a single optimization step. - - Args: - closure (Callable): A closure that reevaluates the model - and returns the loss. 
- """ - - torch.manual_seed(0) - - assert len(self.param_groups) == 1 - - # Make sure the closure is always called with grad enabled - closure = torch.enable_grad()(closure) - - group = self.param_groups[0] - lr = group['lr'] - max_iter = group['max_iter'] - max_eval = group['max_eval'] - tolerance_grad = group['tolerance_grad'] - tolerance_change = group['tolerance_change'] - tolerance_ys = group['tolerance_ys'] - line_search_fn = group['line_search_fn'] - history_size = group['history_size'] - - # NOTE: LBFGS has only global state, but we register it as state for - # the first param, because this helps with casting in load_state_dict - state = self.state[self._params[0]] - state.setdefault('func_evals', 0) - state.setdefault('n_iter', 0) - - # evaluate initial f(x) and df/dx - orig_loss = closure() - loss = float(orig_loss) - current_evals = 1 - state['func_evals'] += 1 - - flat_grad = self._gather_flat_grad() - opt_cond = flat_grad.abs().max() <= tolerance_grad - - # optimal condition - if opt_cond: - return orig_loss - - # tensors cached in state (for tracing) - d = state.get('d') - t = state.get('t') - old_dirs = state.get('old_dirs') - old_stps = state.get('old_stps') - ro = state.get('ro') - H_diag = state.get('H_diag') - prev_flat_grad = state.get('prev_flat_grad') - prev_loss = state.get('prev_loss') - - n_iter = 0 - # optimize for a max of max_iter iterations - while n_iter < max_iter: - # keep track of nb of iterations - n_iter += 1 - state['n_iter'] += 1 - - ############################################################ - # compute gradient descent direction - ############################################################ - if state['n_iter'] == 1: - d = flat_grad.neg() - old_dirs = [] - old_stps = [] - ro = [] - H_diag = 1 - else: - # do lbfgs update (update memory) - y = flat_grad.sub(prev_flat_grad) - s = d.mul(t) - ys = y.dot(s) # y*s - if ys > tolerance_ys: - # updating memory - if len(old_dirs) == history_size: - # shift history by one (limited-memory) - old_dirs.pop(0) - old_stps.pop(0) - ro.pop(0) - - # store new direction/step - old_dirs.append(y) - old_stps.append(s) - ro.append(1. / ys) - - # update scale of initial Hessian approximation - H_diag = ys / y.dot(y) # (y*y) - - # compute the approximate (L-BFGS) inverse Hessian - # multiplied by the gradient - num_old = len(old_dirs) - - if 'al' not in state: - state['al'] = [None] * history_size - al = state['al'] - - # iteration in L-BFGS loop collapsed to use just one buffer - q = flat_grad.neg() - for i in range(num_old - 1, -1, -1): - al[i] = old_stps[i].dot(q) * ro[i] - q.add_(old_dirs[i], alpha=-al[i]) - - # multiply by initial Hessian - # r/d is the final direction - d = r = torch.mul(q, H_diag) - for i in range(num_old): - be_i = old_dirs[i].dot(r) * ro[i] - r.add_(old_stps[i], alpha=al[i] - be_i) - - if prev_flat_grad is None: - prev_flat_grad = flat_grad.clone(memory_format=torch.contiguous_format) - else: - prev_flat_grad.copy_(flat_grad) - prev_loss = loss - - ############################################################ - # compute step length - ############################################################ - # reset initial guess for step size - if state['n_iter'] == 1: - t = min(1., 1. 
/ flat_grad.abs().sum()) * lr - else: - t = lr - - # directional derivative - gtd = flat_grad.dot(d) # g * d - - # directional derivative is below tolerance - if gtd > -tolerance_change: - break - - # optional line search: user function - ls_func_evals = 0 - if line_search_fn is not None: - # perform line search, using user function - if line_search_fn != "strong_wolfe": - raise RuntimeError("only 'strong_wolfe' is supported") - else: - x_init = self._clone_param() - - def obj_func(x, t, d): - return self._directional_evaluate(closure, x, t, d) - loss, flat_grad, t, ls_func_evals = _strong_wolfe( - obj_func, x_init, t, d, loss, flat_grad, gtd) - self._add_grad(t, d) - opt_cond = flat_grad.abs().max() <= tolerance_grad - else: - # no line search, simply move with fixed-step - self._add_grad(t, d) - if n_iter != max_iter: - # re-evaluate function only if not in last iteration - # the reason we do this: in a stochastic setting, - # no use to re-evaluate that function here - with torch.enable_grad(): - loss = float(closure()) - flat_grad = self._gather_flat_grad() - opt_cond = flat_grad.abs().max() <= tolerance_grad - ls_func_evals = 1 - - # update func eval - current_evals += ls_func_evals - state['func_evals'] += ls_func_evals - - ############################################################ - # check conditions - ############################################################ - if n_iter == max_iter: - break - - if current_evals >= max_eval: - break - - # optimal condition - if opt_cond: - break - - # lack of progress - if d.mul(t).abs().max() <= tolerance_change: - break - - if abs(loss - prev_loss) < tolerance_change: - break - - state['d'] = d - state['t'] = t - state['old_dirs'] = old_dirs - state['old_stps'] = old_stps - state['ro'] = ro - state['H_diag'] = H_diag - state['prev_flat_grad'] = prev_flat_grad - state['prev_loss'] = prev_loss - - return orig_loss diff --git a/models/others/kolmogorov_arnold_networks/kan/pytorch/kan/.ipynb_checkpoints/MLP-checkpoint.py b/models/others/kolmogorov_arnold_networks/kan/pytorch/kan/.ipynb_checkpoints/MLP-checkpoint.py deleted file mode 100644 index 1066c3b3db20684b86c6fc2794c8e1c0330b3967..0000000000000000000000000000000000000000 --- a/models/others/kolmogorov_arnold_networks/kan/pytorch/kan/.ipynb_checkpoints/MLP-checkpoint.py +++ /dev/null @@ -1,361 +0,0 @@ -import torch -import torch.nn as nn -import matplotlib.pyplot as plt -import numpy as np -from tqdm import tqdm -from .LBFGS import LBFGS - -seed = 0 -torch.manual_seed(seed) - -class MLP(nn.Module): - - def __init__(self, width, act='silu', save_act=True, seed=0, device='cpu'): - super(MLP, self).__init__() - - torch.manual_seed(seed) - - linears = [] - self.width = width - self.depth = depth = len(width) - 1 - for i in range(depth): - linears.append(nn.Linear(width[i], width[i+1])) - self.linears = nn.ModuleList(linears) - - #if activation == 'silu': - self.act_fun = torch.nn.SiLU() - self.save_act = save_act - self.acts = None - - self.cache_data = None - - self.device = device - self.to(device) - - - def to(self, device): - super(MLP, self).to(device) - self.device = device - - return self - - - def get_act(self, x=None): - if isinstance(x, dict): - x = x['train_input'] - if x == None: - if self.cache_data != None: - x = self.cache_data - else: - raise Exception("missing input data x") - save_act = self.save_act - self.save_act = True - self.forward(x) - self.save_act = save_act - - @property - def w(self): - return [self.linears[l].weight for l in range(self.depth)] - - def forward(self, 
x): - - # cache data - self.cache_data = x - - self.acts = [] - self.acts_scale = [] - self.wa_forward = [] - self.a_forward = [] - - for i in range(self.depth): - - if self.save_act: - act = x.clone() - act_scale = torch.std(x, dim=0) - wa_forward = act_scale[None, :] * self.linears[i].weight - self.acts.append(act) - if i > 0: - self.acts_scale.append(act_scale) - self.wa_forward.append(wa_forward) - - x = self.linears[i](x) - if i < self.depth - 1: - x = self.act_fun(x) - else: - if self.save_act: - act_scale = torch.std(x, dim=0) - self.acts_scale.append(act_scale) - - return x - - def attribute(self): - if self.acts == None: - self.get_act() - - node_scores = [] - edge_scores = [] - - # back propagate from the last layer - node_score = torch.ones(self.width[-1]).requires_grad_(True).to(self.device) - node_scores.append(node_score) - - for l in range(self.depth,0,-1): - - edge_score = torch.einsum('ij,i->ij', torch.abs(self.wa_forward[l-1]), node_score/(self.acts_scale[l-1]+1e-4)) - edge_scores.append(edge_score) - - # this might be improper for MLPs (although reasonable for KANs) - node_score = torch.sum(edge_score, dim=0)/torch.sqrt(torch.tensor(self.width[l-1], device=self.device)) - #print(self.width[l]) - node_scores.append(node_score) - - self.node_scores = list(reversed(node_scores)) - self.edge_scores = list(reversed(edge_scores)) - self.wa_backward = self.edge_scores - - def plot(self, beta=3, scale=1., metric='w'): - # metric = 'w', 'act' or 'fa' - - if metric == 'fa': - self.attribute() - - depth = self.depth - y0 = 0.5 - fig, ax = plt.subplots(figsize=(3*scale,3*y0*depth*scale)) - shp = self.width - - min_spacing = 1/max(self.width) - for j in range(len(shp)): - N = shp[j] - for i in range(N): - plt.scatter(1 / (2 * N) + i / N, j * y0, s=min_spacing ** 2 * 5000 * scale ** 2, color='black') - - plt.ylim(-0.1*y0,y0*depth+0.1*y0) - plt.xlim(-0.02,1.02) - - linears = self.linears - - for ii in range(len(linears)): - linear = linears[ii] - p = linear.weight - p_shp = p.shape - - if metric == 'w': - pass - elif metric == 'act': - p = self.wa_forward[ii] - elif metric == 'fa': - p = self.wa_backward[ii] - else: - raise Exception('metric = \'{}\' not recognized. Choices are \'w\', \'act\', \'fa\'.'.format(metric)) - for i in range(p_shp[0]): - for j in range(p_shp[1]): - plt.plot([1/(2*p_shp[0])+i/p_shp[0], 1/(2*p_shp[1])+j/p_shp[1]], [y0*(ii+1),y0*ii], lw=0.5*scale, alpha=np.tanh(beta*np.abs(p[i,j].cpu().detach().numpy())), color="blue" if p[i,j]>0 else "red") - - ax.axis('off') - - def reg(self, reg_metric, lamb_l1, lamb_entropy): - - if reg_metric == 'w': - acts_scale = self.w - if reg_metric == 'act': - acts_scale = self.wa_forward - if reg_metric == 'fa': - acts_scale = self.wa_backward - if reg_metric == 'a': - acts_scale = self.acts_scale - - if len(acts_scale[0].shape) == 2: - reg_ = 0. - - for i in range(len(acts_scale)): - vec = acts_scale[i] - vec = torch.abs(vec) - - l1 = torch.sum(vec) - p_row = vec / (torch.sum(vec, dim=1, keepdim=True) + 1) - p_col = vec / (torch.sum(vec, dim=0, keepdim=True) + 1) - entropy_row = - torch.mean(torch.sum(p_row * torch.log2(p_row + 1e-4), dim=1)) - entropy_col = - torch.mean(torch.sum(p_col * torch.log2(p_col + 1e-4), dim=0)) - reg_ += lamb_l1 * l1 + lamb_entropy * (entropy_row + entropy_col) - - elif len(acts_scale[0].shape) == 1: - - reg_ = 0. 
- - for i in range(len(acts_scale)): - vec = acts_scale[i] - vec = torch.abs(vec) - - l1 = torch.sum(vec) - p = vec / (torch.sum(vec) + 1) - entropy = - torch.sum(p * torch.log2(p + 1e-4)) - reg_ += lamb_l1 * l1 + lamb_entropy * entropy - - return reg_ - - def get_reg(self, reg_metric, lamb_l1, lamb_entropy): - return self.reg(reg_metric, lamb_l1, lamb_entropy) - - def fit(self, dataset, opt="LBFGS", steps=100, log=1, lamb=0., lamb_l1=1., lamb_entropy=2., loss_fn=None, lr=1., batch=-1, - metrics=None, in_vars=None, out_vars=None, beta=3, device='cpu', reg_metric='w', display_metrics=None): - - if lamb > 0. and not self.save_act: - print('setting lamb=0. If you want to set lamb > 0, set =True') - - old_save_act = self.save_act - if lamb == 0.: - self.save_act = False - - pbar = tqdm(range(steps), desc='description', ncols=100) - - if loss_fn == None: - loss_fn = loss_fn_eval = lambda x, y: torch.mean((x - y) ** 2) - else: - loss_fn = loss_fn_eval = loss_fn - - if opt == "Adam": - optimizer = torch.optim.Adam(self.parameters(), lr=lr) - elif opt == "LBFGS": - optimizer = LBFGS(self.parameters(), lr=lr, history_size=10, line_search_fn="strong_wolfe", tolerance_grad=1e-32, tolerance_change=1e-32, tolerance_ys=1e-32) - - results = {} - results['train_loss'] = [] - results['test_loss'] = [] - results['reg'] = [] - if metrics != None: - for i in range(len(metrics)): - results[metrics[i].__name__] = [] - - if batch == -1 or batch > dataset['train_input'].shape[0]: - batch_size = dataset['train_input'].shape[0] - batch_size_test = dataset['test_input'].shape[0] - else: - batch_size = batch - batch_size_test = batch - - global train_loss, reg_ - - def closure(): - global train_loss, reg_ - optimizer.zero_grad() - pred = self.forward(dataset['train_input'][train_id].to(self.device)) - train_loss = loss_fn(pred, dataset['train_label'][train_id].to(self.device)) - if self.save_act: - if reg_metric == 'fa': - self.attribute() - reg_ = self.get_reg(reg_metric, lamb_l1, lamb_entropy) - else: - reg_ = torch.tensor(0.) - objective = train_loss + lamb * reg_ - objective.backward() - return objective - - for _ in pbar: - - if _ == steps-1 and old_save_act: - self.save_act = True - - train_id = np.random.choice(dataset['train_input'].shape[0], batch_size, replace=False) - test_id = np.random.choice(dataset['test_input'].shape[0], batch_size_test, replace=False) - - if opt == "LBFGS": - optimizer.step(closure) - - if opt == "Adam": - pred = self.forward(dataset['train_input'][train_id].to(self.device)) - train_loss = loss_fn(pred, dataset['train_label'][train_id].to(self.device)) - if self.save_act: - reg_ = self.get_reg(reg_metric, lamb_l1, lamb_entropy) - else: - reg_ = torch.tensor(0.) 
- loss = train_loss + lamb * reg_ - optimizer.zero_grad() - loss.backward() - optimizer.step() - - test_loss = loss_fn_eval(self.forward(dataset['test_input'][test_id].to(self.device)), dataset['test_label'][test_id].to(self.device)) - - - if metrics != None: - for i in range(len(metrics)): - results[metrics[i].__name__].append(metrics[i]().item()) - - results['train_loss'].append(torch.sqrt(train_loss).cpu().detach().numpy()) - results['test_loss'].append(torch.sqrt(test_loss).cpu().detach().numpy()) - results['reg'].append(reg_.cpu().detach().numpy()) - - if _ % log == 0: - if display_metrics == None: - pbar.set_description("| train_loss: %.2e | test_loss: %.2e | reg: %.2e | " % (torch.sqrt(train_loss).cpu().detach().numpy(), torch.sqrt(test_loss).cpu().detach().numpy(), reg_.cpu().detach().numpy())) - else: - string = '' - data = () - for metric in display_metrics: - string += f' {metric}: %.2e |' - try: - results[metric] - except: - raise Exception(f'{metric} not recognized') - data += (results[metric][-1],) - pbar.set_description(string % data) - - return results - - @property - def connection_cost(self): - - with torch.no_grad(): - cc = 0. - for linear in self.linears: - t = torch.abs(linear.weight) - def get_coordinate(n): - return torch.linspace(0,1,steps=n+1, device=self.device)[:n] + 1/(2*n) - - in_dim = t.shape[0] - x_in = get_coordinate(in_dim) - - out_dim = t.shape[1] - x_out = get_coordinate(out_dim) - - dist = torch.abs(x_in[:,None] - x_out[None,:]) - cc += torch.sum(dist * t) - - return cc - - def swap(self, l, i1, i2): - - def swap_row(data, i1, i2): - data[i1], data[i2] = data[i2].clone(), data[i1].clone() - - def swap_col(data, i1, i2): - data[:,i1], data[:,i2] = data[:,i2].clone(), data[:,i1].clone() - - swap_row(self.linears[l-1].weight.data, i1, i2) - swap_row(self.linears[l-1].bias.data, i1, i2) - swap_col(self.linears[l].weight.data, i1, i2) - - def auto_swap_l(self, l): - - num = self.width[l] - for i in range(num): - ccs = [] - for j in range(num): - self.swap(l,i,j) - self.get_act() - self.attribute() - cc = self.connection_cost.detach().clone() - ccs.append(cc) - self.swap(l,i,j) - j = torch.argmin(torch.tensor(ccs)) - self.swap(l,i,j) - - def auto_swap(self): - depth = self.depth - for l in range(1, depth): - self.auto_swap_l(l) - - def tree(self, x=None, in_var=None, style='tree', sym_th=1e-3, sep_th=1e-1, skip_sep_test=False, verbose=False): - if x == None: - x = self.cache_data - plot_tree(self, x, in_var=in_var, style=style, sym_th=sym_th, sep_th=sep_th, skip_sep_test=skip_sep_test, verbose=verbose) \ No newline at end of file diff --git a/models/others/kolmogorov_arnold_networks/kan/pytorch/kan/.ipynb_checkpoints/MultKAN-checkpoint.py b/models/others/kolmogorov_arnold_networks/kan/pytorch/kan/.ipynb_checkpoints/MultKAN-checkpoint.py deleted file mode 100644 index 37f3e58200586b22606f3f15dd1f99f606587568..0000000000000000000000000000000000000000 --- a/models/others/kolmogorov_arnold_networks/kan/pytorch/kan/.ipynb_checkpoints/MultKAN-checkpoint.py +++ /dev/null @@ -1,2805 +0,0 @@ -import torch -import torch.nn as nn -import numpy as np -from .KANLayer import KANLayer -#from .Symbolic_MultKANLayer import * -from .Symbolic_KANLayer import Symbolic_KANLayer -from .LBFGS import * -import os -import glob -import matplotlib.pyplot as plt -from tqdm import tqdm -import random -import copy -#from .MultKANLayer import MultKANLayer -import pandas as pd -from sympy.printing import latex -from sympy import * -import sympy -import yaml -from .spline import curve2coef 
-from .utils import SYMBOLIC_LIB -from .hypothesis import plot_tree - -class MultKAN(nn.Module): - ''' - KAN class - - Attributes: - ----------- - grid : int - the number of grid intervals - k : int - spline order - act_fun : a list of KANLayers - symbolic_fun: a list of Symbolic_KANLayer - depth : int - depth of KAN - width : list - number of neurons in each layer. - Without multiplication nodes, [2,5,5,3] means 2D inputs, 3D outputs, with 2 layers of 5 hidden neurons. - With multiplication nodes, [2,[5,3],[5,1],3] means besides the [2,5,53] KAN, there are 3 (1) mul nodes in layer 1 (2). - mult_arity : int, or list of int lists - multiplication arity for each multiplication node (the number of numbers to be multiplied) - grid : int - the number of grid intervals - k : int - the order of piecewise polynomial - base_fun : fun - residual function b(x). an activation function phi(x) = sb_scale * b(x) + sp_scale * spline(x) - symbolic_fun : a list of Symbolic_KANLayer - Symbolic_KANLayers - symbolic_enabled : bool - If False, the symbolic front is not computed (to save time). Default: True. - width_in : list - The number of input neurons for each layer - width_out : list - The number of output neurons for each layer - base_fun_name : str - The base function b(x) - grip_eps : float - The parameter that interpolates between uniform grid and adaptive grid (based on sample quantile) - node_bias : a list of 1D torch.float - node_scale : a list of 1D torch.float - subnode_bias : a list of 1D torch.float - subnode_scale : a list of 1D torch.float - symbolic_enabled : bool - when symbolic_enabled = False, the symbolic branch (symbolic_fun) will be ignored in computation (set to zero) - affine_trainable : bool - indicate whether affine parameters are trainable (node_bias, node_scale, subnode_bias, subnode_scale) - sp_trainable : bool - indicate whether the overall magnitude of splines is trainable - sb_trainable : bool - indicate whether the overall magnitude of base function is trainable - save_act : bool - indicate whether intermediate activations are saved in forward pass - node_scores : None or list of 1D torch.float - node attribution score - edge_scores : None or list of 2D torch.float - edge attribution score - subnode_scores : None or list of 1D torch.float - subnode attribution score - cache_data : None or 2D torch.float - cached input data - acts : None or a list of 2D torch.float - activations on nodes - auto_save : bool - indicate whether to automatically save a checkpoint once the model is modified - state_id : int - the state of the model (used to save checkpoint) - ckpt_path : str - the folder to store checkpoints - round : int - the number of times rewind() has been called - device : str - ''' - def __init__(self, width=None, grid=3, k=3, mult_arity = 2, noise_scale=0.3, scale_base_mu=0.0, scale_base_sigma=1.0, base_fun='silu', symbolic_enabled=True, affine_trainable=False, grid_eps=0.02, grid_range=[-1, 1], sp_trainable=True, sb_trainable=True, seed=1, save_act=True, sparse_init=False, auto_save=True, first_init=True, ckpt_path='./model', state_id=0, round=0, device='cpu'): - ''' - initalize a KAN model - - Args: - ----- - width : list of int - Without multiplication nodes: :math:`[n_0, n_1, .., n_{L-1}]` specify the number of neurons in each layer (including inputs/outputs) - With multiplication nodes: :math:`[[n_0,m_0=0], [n_1,m_1], .., [n_{L-1},m_{L-1}]]` specify the number of addition/multiplication nodes in each layer (including inputs/outputs) - grid : int - number of grid 
intervals. Default: 3. - k : int - order of piecewise polynomial. Default: 3. - mult_arity : int, or list of int lists - multiplication arity for each multiplication node (the number of numbers to be multiplied) - noise_scale : float - initial injected noise to spline. - base_fun : str - the residual function b(x). Default: 'silu' - symbolic_enabled : bool - compute (True) or skip (False) symbolic computations (for efficiency). By default: True. - affine_trainable : bool - affine parameters are updated or not. Affine parameters include node_scale, node_bias, subnode_scale, subnode_bias - grid_eps : float - When grid_eps = 1, the grid is uniform; when grid_eps = 0, the grid is partitioned using percentiles of samples. 0 < grid_eps < 1 interpolates between the two extremes. - grid_range : list/np.array of shape (2,)) - setting the range of grids. Default: [-1,1]. This argument is not important if fit(update_grid=True) (by default updata_grid=True) - sp_trainable : bool - If true, scale_sp is trainable. Default: True. - sb_trainable : bool - If true, scale_base is trainable. Default: True. - device : str - device - seed : int - random seed - save_act : bool - indicate whether intermediate activations are saved in forward pass - sparse_init : bool - sparse initialization (True) or normal dense initialization. Default: False. - auto_save : bool - indicate whether to automatically save a checkpoint once the model is modified - state_id : int - the state of the model (used to save checkpoint) - ckpt_path : str - the folder to store checkpoints. Default: './model' - round : int - the number of times rewind() has been called - device : str - - Returns: - -------- - self - - Example - ------- - >>> from kan import * - >>> model = KAN(width=[2,5,1], grid=5, k=3, seed=0) - checkpoint directory created: ./model - saving model version 0.0 - ''' - super(MultKAN, self).__init__() - - torch.manual_seed(seed) - np.random.seed(seed) - random.seed(seed) - - ### initializeing the numerical front ### - - self.act_fun = [] - self.depth = len(width) - 1 - - #print('haha1', width) - for i in range(len(width)): - #print(type(width[i]), type(width[i]) == int) - if type(width[i]) == int or type(width[i]) == np.int64: - width[i] = [width[i],0] - - #print('haha2', width) - - self.width = width - - # if mult_arity is just a scalar, we extend it to a list of lists - # e.g, mult_arity = [[2,3],[4]] means that in the first hidden layer, 2 mult ops have arity 2 and 3, respectively; - # in the second hidden layer, 1 mult op has arity 4. - if isinstance(mult_arity, int): - self.mult_homo = True # when homo is True, parallelization is possible - else: - self.mult_homo = False # when home if False, for loop is required. - self.mult_arity = mult_arity - - width_in = self.width_in - width_out = self.width_out - - self.base_fun_name = base_fun - if base_fun == 'silu': - base_fun = torch.nn.SiLU() - elif base_fun == 'identity': - base_fun = torch.nn.Identity() - elif base_fun == 'zero': - base_fun = lambda x: x*0. 
- - self.grid_eps = grid_eps - self.grid_range = grid_range - - - for l in range(self.depth): - # splines - if isinstance(grid, list): - grid_l = grid[l] - else: - grid_l = grid - - if isinstance(k, list): - k_l = k[l] - else: - k_l = k - - - sp_batch = KANLayer(in_dim=width_in[l], out_dim=width_out[l+1], num=grid_l, k=k_l, noise_scale=noise_scale, scale_base_mu=scale_base_mu, scale_base_sigma=scale_base_sigma, scale_sp=1., base_fun=base_fun, grid_eps=grid_eps, grid_range=grid_range, sp_trainable=sp_trainable, sb_trainable=sb_trainable, sparse_init=sparse_init) - self.act_fun.append(sp_batch) - - self.node_bias = [] - self.node_scale = [] - self.subnode_bias = [] - self.subnode_scale = [] - - globals()['self.node_bias_0'] = torch.nn.Parameter(torch.zeros(3,1)).requires_grad_(False) - exec('self.node_bias_0' + " = torch.nn.Parameter(torch.zeros(3,1)).requires_grad_(False)") - - for l in range(self.depth): - exec(f'self.node_bias_{l} = torch.nn.Parameter(torch.zeros(width_in[l+1])).requires_grad_(affine_trainable)') - exec(f'self.node_scale_{l} = torch.nn.Parameter(torch.ones(width_in[l+1])).requires_grad_(affine_trainable)') - exec(f'self.subnode_bias_{l} = torch.nn.Parameter(torch.zeros(width_out[l+1])).requires_grad_(affine_trainable)') - exec(f'self.subnode_scale_{l} = torch.nn.Parameter(torch.ones(width_out[l+1])).requires_grad_(affine_trainable)') - exec(f'self.node_bias.append(self.node_bias_{l})') - exec(f'self.node_scale.append(self.node_scale_{l})') - exec(f'self.subnode_bias.append(self.subnode_bias_{l})') - exec(f'self.subnode_scale.append(self.subnode_scale_{l})') - - - self.act_fun = nn.ModuleList(self.act_fun) - - self.grid = grid - self.k = k - self.base_fun = base_fun - - ### initializing the symbolic front ### - self.symbolic_fun = [] - for l in range(self.depth): - sb_batch = Symbolic_KANLayer(in_dim=width_in[l], out_dim=width_out[l+1]) - self.symbolic_fun.append(sb_batch) - - self.symbolic_fun = nn.ModuleList(self.symbolic_fun) - self.symbolic_enabled = symbolic_enabled - self.affine_trainable = affine_trainable - self.sp_trainable = sp_trainable - self.sb_trainable = sb_trainable - - self.save_act = save_act - - self.node_scores = None - self.edge_scores = None - self.subnode_scores = None - - self.cache_data = None - self.acts = None - - self.auto_save = auto_save - self.state_id = 0 - self.ckpt_path = ckpt_path - self.round = round - - self.device = device - self.to(device) - - if auto_save: - if first_init: - if not os.path.exists(ckpt_path): - # Create the directory - os.makedirs(ckpt_path) - print(f"checkpoint directory created: {ckpt_path}") - print('saving model version 0.0') - - history_path = self.ckpt_path+'/history.txt' - with open(history_path, 'w') as file: - file.write(f'### Round {self.round} ###' + '\n') - file.write('init => 0.0' + '\n') - self.saveckpt(path=self.ckpt_path+'/'+'0.0') - else: - self.state_id = state_id - - self.input_id = torch.arange(self.width_in[0],) - - def to(self, device): - ''' - move the model to device - - Args: - ----- - device : str or device - - Returns: - -------- - self - - Example - ------- - >>> from kan import * - >>> device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') - >>> model = KAN(width=[2,5,1], grid=5, k=3, seed=0) - >>> model.to(device) - ''' - super(MultKAN, self).to(device) - self.device = device - - for kanlayer in self.act_fun: - kanlayer.to(device) - - for symbolic_kanlayer in self.symbolic_fun: - symbolic_kanlayer.to(device) - - return self - - @property - def width_in(self): - ''' - The 
number of input nodes for each layer - ''' - width = self.width - width_in = [width[l][0]+width[l][1] for l in range(len(width))] - return width_in - - @property - def width_out(self): - ''' - The number of output subnodes for each layer - ''' - width = self.width - if self.mult_homo == True: - width_out = [width[l][0]+self.mult_arity*width[l][1] for l in range(len(width))] - else: - width_out = [width[l][0]+int(np.sum(self.mult_arity[l])) for l in range(len(width))] - return width_out - - @property - def n_sum(self): - ''' - The number of addition nodes for each layer - ''' - width = self.width - n_sum = [width[l][0] for l in range(1,len(width)-1)] - return n_sum - - @property - def n_mult(self): - ''' - The number of multiplication nodes for each layer - ''' - width = self.width - n_mult = [width[l][1] for l in range(1,len(width)-1)] - return n_mult - - @property - def feature_score(self): - ''' - attribution scores for inputs - ''' - self.attribute() - if self.node_scores == None: - return None - else: - return self.node_scores[0] - - def initialize_from_another_model(self, another_model, x): - ''' - initialize from another model of the same width, but their 'grid' parameter can be different. - Note this is equivalent to refine() when we don't want to keep another_model - - Args: - ----- - another_model : MultKAN - x : 2D torch.float - - Returns: - -------- - self - - Example - ------- - >>> from kan import * - >>> model1 = KAN(width=[2,5,1], grid=3) - >>> model2 = KAN(width=[2,5,1], grid=10) - >>> x = torch.rand(100,2) - >>> model2.initialize_from_another_model(model1, x) - ''' - another_model(x) # get activations - batch = x.shape[0] - - self.initialize_grid_from_another_model(another_model, x) - - for l in range(self.depth): - spb = self.act_fun[l] - #spb_parent = another_model.act_fun[l] - - # spb = spb_parent - preacts = another_model.spline_preacts[l] - postsplines = another_model.spline_postsplines[l] - self.act_fun[l].coef.data = curve2coef(preacts[:,0,:], postsplines.permute(0,2,1), spb.grid, k=spb.k) - self.act_fun[l].scale_base.data = another_model.act_fun[l].scale_base.data - self.act_fun[l].scale_sp.data = another_model.act_fun[l].scale_sp.data - self.act_fun[l].mask.data = another_model.act_fun[l].mask.data - - for l in range(self.depth): - self.node_bias[l].data = another_model.node_bias[l].data - self.node_scale[l].data = another_model.node_scale[l].data - - self.subnode_bias[l].data = another_model.subnode_bias[l].data - self.subnode_scale[l].data = another_model.subnode_scale[l].data - - for l in range(self.depth): - self.symbolic_fun[l] = another_model.symbolic_fun[l] - - return self.to(self.device) - - def log_history(self, method_name): - - if self.auto_save: - - # save to log file - #print(func.__name__) - with open(self.ckpt_path+'/history.txt', 'a') as file: - file.write(str(self.round)+'.'+str(self.state_id)+' => '+ method_name + ' => ' + str(self.round)+'.'+str(self.state_id+1) + '\n') - - # update state_id - self.state_id += 1 - - # save to ckpt - self.saveckpt(path=self.ckpt_path+'/'+str(self.round)+'.'+str(self.state_id)) - print('saving model version '+str(self.round)+'.'+str(self.state_id)) - - - def refine(self, new_grid): - ''' - grid refinement - - Args: - ----- - new_grid : init - the number of grid intervals after refinement - - Returns: - -------- - a refined model : MultKAN - - Example - ------- - >>> from kan import * - >>> device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') - >>> model = KAN(width=[2,5,1], grid=5, k=3, seed=0) - 
>>> print(model.grid) - >>> x = torch.rand(100,2) - >>> model.get_act(x) - >>> model = model.refine(10) - >>> print(model.grid) - checkpoint directory created: ./model - saving model version 0.0 - 5 - saving model version 0.1 - 10 - ''' - - model_new = MultKAN(width=self.width, - grid=new_grid, - k=self.k, - mult_arity=self.mult_arity, - base_fun=self.base_fun_name, - symbolic_enabled=self.symbolic_enabled, - affine_trainable=self.affine_trainable, - grid_eps=self.grid_eps, - grid_range=self.grid_range, - sp_trainable=self.sp_trainable, - sb_trainable=self.sb_trainable, - ckpt_path=self.ckpt_path, - auto_save=True, - first_init=False, - state_id=self.state_id, - round=self.round, - device=self.device) - - model_new.initialize_from_another_model(self, self.cache_data) - model_new.cache_data = self.cache_data - model_new.grid = new_grid - - self.log_history('refine') - model_new.state_id += 1 - - return model_new.to(self.device) - - - def saveckpt(self, path='model'): - ''' - save the current model to files (configuration file and state file) - - Args: - ----- - path : str - the path where checkpoints are saved - - Returns: - -------- - None - - Example - ------- - >>> from kan import * - >>> device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') - >>> model = KAN(width=[2,5,1], grid=5, k=3, seed=0) - >>> model.saveckpt('./mark') - # There will be three files appearing in the current folder: mark_cache_data, mark_config.yml, mark_state - ''' - - model = self - - dic = dict( - width = model.width, - grid = model.grid, - k = model.k, - mult_arity = model.mult_arity, - base_fun_name = model.base_fun_name, - symbolic_enabled = model.symbolic_enabled, - affine_trainable = model.affine_trainable, - grid_eps = model.grid_eps, - grid_range = model.grid_range, - sp_trainable = model.sp_trainable, - sb_trainable = model.sb_trainable, - state_id = model.state_id, - auto_save = model.auto_save, - ckpt_path = model.ckpt_path, - round = model.round, - device = str(model.device) - ) - - for i in range (model.depth): - dic[f'symbolic.funs_name.{i}'] = model.symbolic_fun[i].funs_name - - with open(f'{path}_config.yml', 'w') as outfile: - yaml.dump(dic, outfile, default_flow_style=False) - - torch.save(model.state_dict(), f'{path}_state') - torch.save(model.cache_data, f'{path}_cache_data') - - @staticmethod - def loadckpt(path='model'): - ''' - load checkpoint from path - - Args: - ----- - path : str - the path where checkpoints are saved - - Returns: - -------- - MultKAN - - Example - ------- - >>> from kan import * - >>> model = KAN(width=[2,5,1], grid=5, k=3, seed=0) - >>> model.saveckpt('./mark') - >>> KAN.loadckpt('./mark') - ''' - with open(f'{path}_config.yml', 'r') as stream: - config = yaml.safe_load(stream) - - state = torch.load(f'{path}_state') - - model_load = MultKAN(width=config['width'], - grid=config['grid'], - k=config['k'], - mult_arity = config['mult_arity'], - base_fun=config['base_fun_name'], - symbolic_enabled=config['symbolic_enabled'], - affine_trainable=config['affine_trainable'], - grid_eps=config['grid_eps'], - grid_range=config['grid_range'], - sp_trainable=config['sp_trainable'], - sb_trainable=config['sb_trainable'], - state_id=config['state_id'], - auto_save=config['auto_save'], - first_init=False, - ckpt_path=config['ckpt_path'], - round = config['round']+1, - device = config['device']) - - model_load.load_state_dict(state) - model_load.cache_data = torch.load(f'{path}_cache_data') - - depth = len(model_load.width) - 1 - for l in range(depth): - out_dim = 
model_load.symbolic_fun[l].out_dim - in_dim = model_load.symbolic_fun[l].in_dim - funs_name = config[f'symbolic.funs_name.{l}'] - for j in range(out_dim): - for i in range(in_dim): - fun_name = funs_name[j][i] - model_load.symbolic_fun[l].funs_name[j][i] = fun_name - model_load.symbolic_fun[l].funs[j][i] = SYMBOLIC_LIB[fun_name][0] - model_load.symbolic_fun[l].funs_sympy[j][i] = SYMBOLIC_LIB[fun_name][1] - model_load.symbolic_fun[l].funs_avoid_singularity[j][i] = SYMBOLIC_LIB[fun_name][3] - return model_load - - def copy(self): - ''' - deepcopy - - Args: - ----- - path : str - the path where checkpoints are saved - - Returns: - -------- - MultKAN - - Example - ------- - >>> from kan import * - >>> model = KAN(width=[1,1], grid=5, k=3, seed=0) - >>> model2 = model.copy() - >>> model2.act_fun[0].coef.data *= 2 - >>> print(model2.act_fun[0].coef.data) - >>> print(model.act_fun[0].coef.data) - ''' - path='copy_temp' - self.saveckpt(path) - return KAN.loadckpt(path) - - def rewind(self, model_id): - ''' - rewind to an old version - - Args: - ----- - model_id : str - in format '{a}.{b}' where a is the round number, b is the version number in that round - - Returns: - -------- - MultKAN - - Example - ------- - Please refer to tutorials. API 12: Checkpoint, save & load model - ''' - self.round += 1 - self.state_id = model_id.split('.')[-1] - - history_path = self.ckpt_path+'/history.txt' - with open(history_path, 'a') as file: - file.write(f'### Round {self.round} ###' + '\n') - - self.saveckpt(path=self.ckpt_path+'/'+f'{self.round}.{self.state_id}') - - print('rewind to model version '+f'{self.round-1}.{self.state_id}'+', renamed as '+f'{self.round}.{self.state_id}') - - return MultKAN.loadckpt(path=self.ckpt_path+'/'+str(model_id)) - - - def checkout(self, model_id): - ''' - check out an old version - - Args: - ----- - model_id : str - in format '{a}.{b}' where a is the round number, b is the version number in that round - - Returns: - -------- - MultKAN - - Example - ------- - Same use as rewind, although checkout doesn't change states - ''' - return MultKAN.loadckpt(path=self.ckpt_path+'/'+str(model_id)) - - def update_grid_from_samples(self, x): - ''' - update grid from samples - - Args: - ----- - x : 2D torch.tensor - inputs - - Returns: - -------- - None - - Example - ------- - >>> from kan import * - >>> model = KAN(width=[1,1], grid=5, k=3, seed=0) - >>> print(model.act_fun[0].grid) - >>> x = torch.linspace(-10,10,steps=101)[:,None] - >>> model.update_grid_from_samples(x) - >>> print(model.act_fun[0].grid) - ''' - for l in range(self.depth): - self.get_act(x) - self.act_fun[l].update_grid_from_samples(self.acts[l]) - - def update_grid(self, x): - ''' - call update_grid_from_samples. 
This seems unnecessary but we retain it for the sake of classes that might inherit from MultKAN - ''' - self.update_grid_from_samples(x) - - def initialize_grid_from_another_model(self, model, x): - ''' - initialize grid from another model - - Args: - ----- - model : MultKAN - parent model - x : 2D torch.tensor - inputs - - Returns: - -------- - None - - Example - ------- - >>> from kan import * - >>> model = KAN(width=[1,1], grid=5, k=3, seed=0) - >>> print(model.act_fun[0].grid) - >>> x = torch.linspace(-10,10,steps=101)[:,None] - >>> model2 = KAN(width=[1,1], grid=10, k=3, seed=0) - >>> model2.initialize_grid_from_another_model(model, x) - >>> print(model2.act_fun[0].grid) - ''' - model(x) - for l in range(self.depth): - self.act_fun[l].initialize_grid_from_parent(model.act_fun[l], model.acts[l]) - - def forward(self, x, singularity_avoiding=False, y_th=10.): - ''' - forward pass - - Args: - ----- - x : 2D torch.tensor - inputs - singularity_avoiding : bool - whether to avoid singularity for the symbolic branch - y_th : float - the threshold for singularity - - Returns: - -------- - None - - Example1 - -------- - >>> from kan import * - >>> model = KAN(width=[2,5,1], grid=5, k=3, seed=0) - >>> x = torch.rand(100,2) - >>> model(x).shape - - Example2 - -------- - >>> from kan import * - >>> model = KAN(width=[1,1], grid=5, k=3, seed=0) - >>> x = torch.tensor([[1],[-0.01]]) - >>> model.fix_symbolic(0,0,0,'log',fit_params_bool=False) - >>> print(model(x)) - >>> print(model(x, singularity_avoiding=True)) - >>> print(model(x, singularity_avoiding=True, y_th=1.)) - ''' - x = x[:,self.input_id.long()] - assert x.shape[1] == self.width_in[0] - - # cache data - self.cache_data = x - - self.acts = [] # shape ([batch, n0], [batch, n1], ..., [batch, n_L]) - self.acts_premult = [] - self.spline_preacts = [] - self.spline_postsplines = [] - self.spline_postacts = [] - self.acts_scale = [] - self.acts_scale_spline = [] - self.subnode_actscale = [] - self.edge_actscale = [] - # self.neurons_scale = [] - - self.acts.append(x) # acts shape: (batch, width[l]) - - for l in range(self.depth): - - x_numerical, preacts, postacts_numerical, postspline = self.act_fun[l](x) - #print(preacts, postacts_numerical, postspline) - - if self.symbolic_enabled == True: - x_symbolic, postacts_symbolic = self.symbolic_fun[l](x, singularity_avoiding=singularity_avoiding, y_th=y_th) - else: - x_symbolic = 0. - postacts_symbolic = 0. 
- - x = x_numerical + x_symbolic - - if self.save_act: - # save subnode_scale - self.subnode_actscale.append(torch.std(x, dim=0).detach()) - - # subnode affine transform - x = self.subnode_scale[l][None,:] * x + self.subnode_bias[l][None,:] - - if self.save_act: - postacts = postacts_numerical + postacts_symbolic - - # self.neurons_scale.append(torch.mean(torch.abs(x), dim=0)) - #grid_reshape = self.act_fun[l].grid.reshape(self.width_out[l + 1], self.width_in[l], -1) - input_range = torch.std(preacts, dim=0) + 0.1 - output_range_spline = torch.std(postacts_numerical, dim=0) # for training, only penalize the spline part - output_range = torch.std(postacts, dim=0) # for visualization, include the contribution from both spline + symbolic - # save edge_scale - self.edge_actscale.append(output_range) - - self.acts_scale.append((output_range / input_range).detach()) - self.acts_scale_spline.append(output_range_spline / input_range) - self.spline_preacts.append(preacts.detach()) - self.spline_postacts.append(postacts.detach()) - self.spline_postsplines.append(postspline.detach()) - - self.acts_premult.append(x.detach()) - - # multiplication - dim_sum = self.width[l+1][0] - dim_mult = self.width[l+1][1] - - if self.mult_homo == True: - for i in range(self.mult_arity-1): - if i == 0: - x_mult = x[:,dim_sum::self.mult_arity] * x[:,dim_sum+1::self.mult_arity] - else: - x_mult = x_mult * x[:,dim_sum+i+1::self.mult_arity] - - else: - for j in range(dim_mult): - acml_id = dim_sum + np.sum(self.mult_arity[l+1][:j]) - for i in range(self.mult_arity[l+1][j]-1): - if i == 0: - x_mult_j = x[:,[acml_id]] * x[:,[acml_id+1]] - else: - x_mult_j = x_mult_j * x[:,[acml_id+i+1]] - - if j == 0: - x_mult = x_mult_j - else: - x_mult = torch.cat([x_mult, x_mult_j], dim=1) - - if self.width[l+1][1] > 0: - x = torch.cat([x[:,:dim_sum], x_mult], dim=1) - - # x = x + self.biases[l].weight - # node affine transform - x = self.node_scale[l][None,:] * x + self.node_bias[l][None,:] - - self.acts.append(x.detach()) - - - return x - - def set_mode(self, l, i, j, mode, mask_n=None): - if mode == "s": - mask_n = 0.; - mask_s = 1. - elif mode == "n": - mask_n = 1.; - mask_s = 0. - elif mode == "sn" or mode == "ns": - if mask_n == None: - mask_n = 1. - else: - mask_n = mask_n - mask_s = 1. - else: - mask_n = 0.; - mask_s = 0. - - self.act_fun[l].mask.data[i][j] = mask_n - self.symbolic_fun[l].mask.data[j,i] = mask_s - - def fix_symbolic(self, l, i, j, fun_name, fit_params_bool=True, a_range=(-10, 10), b_range=(-10, 10), verbose=True, random=False, log_history=True): - ''' - set (l,i,j) activation to be symbolic (specified by fun_name) - - Args: - ----- - l : int - layer index - i : int - input neuron index - j : int - output neuron index - fun_name : str - function name - fit_params_bool : bool - obtaining affine parameters through fitting (True) or setting default values (False) - a_range : tuple - sweeping range of a - b_range : tuple - sweeping range of b - verbose : bool - If True, more information is printed. 
- random : bool - initialize affine parameteres randomly or as [1,0,1,0] - log_history : bool - indicate whether to log history when the function is called - - Returns: - -------- - None or r2 (coefficient of determination) - - Example 1 - --------- - >>> # when fit_params_bool = False - >>> model = KAN(width=[2,5,1], grid=5, k=3) - >>> model.fix_symbolic(0,1,3,'sin',fit_params_bool=False) - >>> print(model.act_fun[0].mask.reshape(2,5)) - >>> print(model.symbolic_fun[0].mask.reshape(2,5)) - - Example 2 - --------- - >>> # when fit_params_bool = True - >>> model = KAN(width=[2,5,1], grid=5, k=3, noise_scale=1.) - >>> x = torch.normal(0,1,size=(100,2)) - >>> model(x) # obtain activations (otherwise model does not have attributes acts) - >>> model.fix_symbolic(0,1,3,'sin',fit_params_bool=True) - >>> print(model.act_fun[0].mask.reshape(2,5)) - >>> print(model.symbolic_fun[0].mask.reshape(2,5)) - ''' - if not fit_params_bool: - self.symbolic_fun[l].fix_symbolic(i, j, fun_name, verbose=verbose, random=random) - r2 = None - else: - x = self.acts[l][:, i] - mask = self.act_fun[l].mask - y = self.spline_postacts[l][:, j, i] - #y = self.postacts[l][:, j, i] - r2 = self.symbolic_fun[l].fix_symbolic(i, j, fun_name, x, y, a_range=a_range, b_range=b_range, verbose=verbose) - if mask[i,j] == 0: - r2 = - 1e8 - self.set_mode(l, i, j, mode="s") - - if log_history: - self.log_history('fix_symbolic') - return r2 - - def unfix_symbolic(self, l, i, j, log_history=True): - ''' - unfix the (l,i,j) activation function. - ''' - self.set_mode(l, i, j, mode="n") - self.symbolic_fun[l].funs_name[j][i] = "0" - if log_history: - self.log_history('unfix_symbolic') - - def unfix_symbolic_all(self, log_history=True): - ''' - unfix all activation functions. - ''' - for l in range(len(self.width) - 1): - for i in range(self.width_in[l]): - for j in range(self.width_out[l + 1]): - self.unfix_symbolic(l, i, j, log_history) - - def get_range(self, l, i, j, verbose=True): - ''' - Get the input range and output range of the (l,i,j) activation - - Args: - ----- - l : int - layer index - i : int - input neuron index - j : int - output neuron index - - Returns: - -------- - x_min : float - minimum of input - x_max : float - maximum of input - y_min : float - minimum of output - y_max : float - maximum of output - - Example - ------- - >>> model = KAN(width=[2,3,1], grid=5, k=3, noise_scale=1.) - >>> x = torch.normal(0,1,size=(100,2)) - >>> model(x) # do a forward pass to obtain model.acts - >>> model.get_range(0,0,0) - ''' - x = self.spline_preacts[l][:, j, i] - y = self.spline_postacts[l][:, j, i] - x_min = torch.min(x).cpu().detach().numpy() - x_max = torch.max(x).cpu().detach().numpy() - y_min = torch.min(y).cpu().detach().numpy() - y_max = torch.max(y).cpu().detach().numpy() - if verbose: - print('x range: [' + '%.2f' % x_min, ',', '%.2f' % x_max, ']') - print('y range: [' + '%.2f' % y_min, ',', '%.2f' % y_max, ']') - return x_min, x_max, y_min, y_max - - def plot(self, folder="./figures", beta=3, metric='backward', scale=0.5, tick=False, sample=False, in_vars=None, out_vars=None, title=None, varscale=1.0): - ''' - plot KAN - - Args: - ----- - folder : str - the folder to store pngs - beta : float - positive number. control the transparency of each activation. transparency = tanh(beta*l1). - mask : bool - If True, plot with mask (need to run prune() first to obtain mask). If False (by default), plot all activation functions. - mode : bool - "supervised" or "unsupervised". 
If "supervised", l1 is measured by absolution value (not subtracting mean); if "unsupervised", l1 is measured by standard deviation (subtracting mean). - scale : float - control the size of the diagram - in_vars: None or list of str - the name(s) of input variables - out_vars: None or list of str - the name(s) of output variables - title: None or str - title - varscale : float - the size of input variables - - Returns: - -------- - Figure - - Example - ------- - >>> # see more interactive examples in demos - >>> model = KAN(width=[2,3,1], grid=3, k=3, noise_scale=1.0) - >>> x = torch.normal(0,1,size=(100,2)) - >>> model(x) # do a forward pass to obtain model.acts - >>> model.plot() - ''' - global Symbol - - if not self.save_act: - print('cannot plot since data are not saved. Set save_act=True first.') - - # forward to obtain activations - if self.acts == None: - if self.cache_data == None: - raise Exception('model hasn\'t seen any data yet.') - self.forward(self.cache_data) - - if metric == 'backward': - self.attribute() - - - if not os.path.exists(folder): - os.makedirs(folder) - # matplotlib.use('Agg') - depth = len(self.width) - 1 - for l in range(depth): - w_large = 2.0 - for i in range(self.width_in[l]): - for j in range(self.width_out[l+1]): - rank = torch.argsort(self.acts[l][:, i]) - fig, ax = plt.subplots(figsize=(w_large, w_large)) - - num = rank.shape[0] - - #print(self.width_in[l]) - #print(self.width_out[l+1]) - symbolic_mask = self.symbolic_fun[l].mask[j][i] - numeric_mask = self.act_fun[l].mask[i][j] - if symbolic_mask > 0. and numeric_mask > 0.: - color = 'purple' - alpha_mask = 1 - if symbolic_mask > 0. and numeric_mask == 0.: - color = "red" - alpha_mask = 1 - if symbolic_mask == 0. and numeric_mask > 0.: - color = "black" - alpha_mask = 1 - if symbolic_mask == 0. 
and numeric_mask == 0.: - color = "white" - alpha_mask = 0 - - - if tick == True: - ax.tick_params(axis="y", direction="in", pad=-22, labelsize=50) - ax.tick_params(axis="x", direction="in", pad=-15, labelsize=50) - x_min, x_max, y_min, y_max = self.get_range(l, i, j, verbose=False) - plt.xticks([x_min, x_max], ['%2.f' % x_min, '%2.f' % x_max]) - plt.yticks([y_min, y_max], ['%2.f' % y_min, '%2.f' % y_max]) - else: - plt.xticks([]) - plt.yticks([]) - if alpha_mask == 1: - plt.gca().patch.set_edgecolor('black') - else: - plt.gca().patch.set_edgecolor('white') - plt.gca().patch.set_linewidth(1.5) - # plt.axis('off') - - plt.plot(self.acts[l][:, i][rank].cpu().detach().numpy(), self.spline_postacts[l][:, j, i][rank].cpu().detach().numpy(), color=color, lw=5) - if sample == True: - plt.scatter(self.acts[l][:, i][rank].cpu().detach().numpy(), self.spline_postacts[l][:, j, i][rank].cpu().detach().numpy(), color=color, s=400 * scale ** 2) - plt.gca().spines[:].set_color(color) - - plt.savefig(f'{folder}/sp_{l}_{i}_{j}.png', bbox_inches="tight", dpi=400) - plt.close() - - def score2alpha(score): - return np.tanh(beta * score) - - - if metric == 'forward_n': - scores = self.acts_scale - elif metric == 'forward_u': - scores = self.edge_actscale - elif metric == 'backward': - scores = self.edge_scores - else: - raise Exception(f'metric = \'{metric}\' not recognized') - - alpha = [score2alpha(score.cpu().detach().numpy()) for score in scores] - - # draw skeleton - width = np.array(self.width) - width_in = np.array(self.width_in) - width_out = np.array(self.width_out) - A = 1 - y0 = 0.3 # height: from input to pre-mult - z0 = 0.1 # height: from pre-mult to post-mult (input of next layer) - - neuron_depth = len(width) - min_spacing = A / np.maximum(np.max(width_out), 5) - - max_neuron = np.max(width_out) - max_num_weights = np.max(width_in[:-1] * width_out[1:]) - y1 = 0.4 / np.maximum(max_num_weights, 5) # size (height/width) of 1D function diagrams - y2 = 0.15 / np.maximum(max_neuron, 5) # size (height/width) of operations (sum and mult) - - fig, ax = plt.subplots(figsize=(10 * scale, 10 * scale * (neuron_depth - 1) * (y0+z0))) - # fig, ax = plt.subplots(figsize=(5,5*(neuron_depth-1)*y0)) - - # -- Transformation functions - DC_to_FC = ax.transData.transform - FC_to_NFC = fig.transFigure.inverted().transform - # -- Take data coordinates and transform them to normalized figure coordinates - DC_to_NFC = lambda x: FC_to_NFC(DC_to_FC(x)) - - # plot scatters and lines - for l in range(neuron_depth): - - n = width_in[l] - - # scatters - for i in range(n): - plt.scatter(1 / (2 * n) + i / n, l * (y0+z0), s=min_spacing ** 2 * 10000 * scale ** 2, color='black') - - # plot connections (input to pre-mult) - for i in range(n): - if l < neuron_depth - 1: - n_next = width_out[l+1] - N = n * n_next - for j in range(n_next): - id_ = i * n_next + j - - symbol_mask = self.symbolic_fun[l].mask[j][i] - numerical_mask = self.act_fun[l].mask[i][j] - if symbol_mask == 1. and numerical_mask > 0.: - color = 'purple' - alpha_mask = 1. - if symbol_mask == 1. and numerical_mask == 0.: - color = "red" - alpha_mask = 1. - if symbol_mask == 0. and numerical_mask == 1.: - color = "black" - alpha_mask = 1. - if symbol_mask == 0. and numerical_mask == 0.: - color = "white" - alpha_mask = 0. 
- - plt.plot([1 / (2 * n) + i / n, 1 / (2 * N) + id_ / N], [l * (y0+z0), l * (y0+z0) + y0/2 - y1], color=color, lw=2 * scale, alpha=alpha[l][j][i] * alpha_mask) - plt.plot([1 / (2 * N) + id_ / N, 1 / (2 * n_next) + j / n_next], [l * (y0+z0) + y0/2 + y1, l * (y0+z0)+y0], color=color, lw=2 * scale, alpha=alpha[l][j][i] * alpha_mask) - - - # plot connections (pre-mult to post-mult, post-mult = next-layer input) - if l < neuron_depth - 1: - n_in = width_out[l+1] - n_out = width_in[l+1] - mult_id = 0 - for i in range(n_in): - if i < width[l+1][0]: - j = i - else: - if i == width[l+1][0]: - if isinstance(self.mult_arity,int): - ma = self.mult_arity - else: - ma = self.mult_arity[l+1][mult_id] - current_mult_arity = ma - if current_mult_arity == 0: - mult_id += 1 - if isinstance(self.mult_arity,int): - ma = self.mult_arity - else: - ma = self.mult_arity[l+1][mult_id] - current_mult_arity = ma - j = width[l+1][0] + mult_id - current_mult_arity -= 1 - #j = (i-width[l+1][0])//self.mult_arity + width[l+1][0] - plt.plot([1 / (2 * n_in) + i / n_in, 1 / (2 * n_out) + j / n_out], [l * (y0+z0) + y0, (l+1) * (y0+z0)], color='black', lw=2 * scale) - - - - plt.xlim(0, 1) - plt.ylim(-0.1 * (y0+z0), (neuron_depth - 1 + 0.1) * (y0+z0)) - - - plt.axis('off') - - for l in range(neuron_depth - 1): - # plot splines - n = width_in[l] - for i in range(n): - n_next = width_out[l + 1] - N = n * n_next - for j in range(n_next): - id_ = i * n_next + j - im = plt.imread(f'{folder}/sp_{l}_{i}_{j}.png') - left = DC_to_NFC([1 / (2 * N) + id_ / N - y1, 0])[0] - right = DC_to_NFC([1 / (2 * N) + id_ / N + y1, 0])[0] - bottom = DC_to_NFC([0, l * (y0+z0) + y0/2 - y1])[1] - up = DC_to_NFC([0, l * (y0+z0) + y0/2 + y1])[1] - newax = fig.add_axes([left, bottom, right - left, up - bottom]) - # newax = fig.add_axes([1/(2*N)+id_/N-y1, (l+1/2)*y0-y1, y1, y1], anchor='NE') - newax.imshow(im, alpha=alpha[l][j][i]) - newax.axis('off') - - - # plot sum symbols - N = n = width_out[l+1] - for j in range(n): - id_ = j - path = os.path.dirname(os.path.abspath(__file__)) + "/assets/img/sum_symbol.png" - im = plt.imread(path) - left = DC_to_NFC([1 / (2 * N) + id_ / N - y2, 0])[0] - right = DC_to_NFC([1 / (2 * N) + id_ / N + y2, 0])[0] - bottom = DC_to_NFC([0, l * (y0+z0) + y0 - y2])[1] - up = DC_to_NFC([0, l * (y0+z0) + y0 + y2])[1] - newax = fig.add_axes([left, bottom, right - left, up - bottom]) - newax.imshow(im) - newax.axis('off') - - # plot mult symbols - N = n = width_in[l+1] - n_sum = width[l+1][0] - n_mult = width[l+1][1] - for j in range(n_mult): - id_ = j + n_sum - path = os.path.dirname(os.path.abspath(__file__)) + "/assets/img/mult_symbol.png" - im = plt.imread(path) - left = DC_to_NFC([1 / (2 * N) + id_ / N - y2, 0])[0] - right = DC_to_NFC([1 / (2 * N) + id_ / N + y2, 0])[0] - bottom = DC_to_NFC([0, (l+1) * (y0+z0) - y2])[1] - up = DC_to_NFC([0, (l+1) * (y0+z0) + y2])[1] - newax = fig.add_axes([left, bottom, right - left, up - bottom]) - newax.imshow(im) - newax.axis('off') - - if in_vars != None: - n = self.width_in[0] - for i in range(n): - if isinstance(in_vars[i], sympy.Expr): - plt.gcf().get_axes()[0].text(1 / (2 * (n)) + i / (n), -0.1, f'${latex(in_vars[i])}$', fontsize=40 * scale * varscale, horizontalalignment='center', verticalalignment='center') - else: - plt.gcf().get_axes()[0].text(1 / (2 * (n)) + i / (n), -0.1, in_vars[i], fontsize=40 * scale * varscale, horizontalalignment='center', verticalalignment='center') - - - - if out_vars != None: - n = self.width_in[-1] - for i in range(n): - if isinstance(out_vars[i], 
sympy.Expr): - plt.gcf().get_axes()[0].text(1 / (2 * (n)) + i / (n), (y0+z0) * (len(self.width) - 1) + 0.15, f'${latex(out_vars[i])}$', fontsize=40 * scale * varscale, horizontalalignment='center', verticalalignment='center') - else: - plt.gcf().get_axes()[0].text(1 / (2 * (n)) + i / (n), (y0+z0) * (len(self.width) - 1) + 0.15, out_vars[i], fontsize=40 * scale * varscale, horizontalalignment='center', verticalalignment='center') - - if title != None: - plt.gcf().get_axes()[0].text(0.5, (y0+z0) * (len(self.width) - 1) + 0.3, title, fontsize=40 * scale, horizontalalignment='center', verticalalignment='center') - - - def reg(self, reg_metric, lamb_l1, lamb_entropy, lamb_coef, lamb_coefdiff): - ''' - Get regularization - - Args: - ----- - reg_metric : the regularization metric - 'edge_forward_spline_n', 'edge_forward_spline_u', 'edge_forward_sum', 'edge_backward', 'node_backward' - lamb_l1 : float - l1 penalty strength - lamb_entropy : float - entropy penalty strength - lamb_coef : float - coefficient penalty strength - lamb_coefdiff : float - coefficient smoothness strength - - Returns: - -------- - reg_ : torch.float - - Example - ------- - >>> model = KAN(width=[2,3,1], grid=5, k=3, noise_scale=1.) - >>> x = torch.rand(100,2) - >>> model.get_act(x) - >>> model.reg('edge_forward_spline_n', 1.0, 2.0, 1.0, 1.0) - ''' - if reg_metric == 'edge_forward_spline_n': - acts_scale = self.acts_scale_spline - - elif reg_metric == 'edge_forward_sum': - acts_scale = self.acts_scale - - elif reg_metric == 'edge_forward_spline_u': - acts_scale = self.edge_actscale - - elif reg_metric == 'edge_backward': - acts_scale = self.edge_scores - - elif reg_metric == 'node_backward': - acts_scale = self.node_attribute_scores - - else: - raise Exception(f'reg_metric = {reg_metric} not recognized!') - - reg_ = 0. - for i in range(len(acts_scale)): - vec = acts_scale[i] - - l1 = torch.sum(vec) - p_row = vec / (torch.sum(vec, dim=1, keepdim=True) + 1) - p_col = vec / (torch.sum(vec, dim=0, keepdim=True) + 1) - entropy_row = - torch.mean(torch.sum(p_row * torch.log2(p_row + 1e-4), dim=1)) - entropy_col = - torch.mean(torch.sum(p_col * torch.log2(p_col + 1e-4), dim=0)) - reg_ += lamb_l1 * l1 + lamb_entropy * (entropy_row + entropy_col) # both l1 and entropy - - # regularize coefficient to encourage spline to be zero - for i in range(len(self.act_fun)): - coeff_l1 = torch.sum(torch.mean(torch.abs(self.act_fun[i].coef), dim=1)) - coeff_diff_l1 = torch.sum(torch.mean(torch.abs(torch.diff(self.act_fun[i].coef)), dim=1)) - reg_ += lamb_coef * coeff_l1 + lamb_coefdiff * coeff_diff_l1 - - return reg_ - - def get_reg(self, reg_metric, lamb_l1, lamb_entropy, lamb_coef, lamb_coefdiff): - ''' - Get regularization. This seems unnecessary but in case a class wants to inherit this, it may want to rewrite get_reg, but not reg. 
-        '''
-        return self.reg(reg_metric, lamb_l1, lamb_entropy, lamb_coef, lamb_coefdiff)
-
-    def disable_symbolic_in_fit(self, lamb):
-        '''
-        during fitting, turn off save_act if lamb = 0, and disable the symbolic branch if none of the symbolic functions is active
-        '''
-        old_save_act = self.save_act
-        if lamb == 0.:
-            self.save_act = False
-
-        # skip symbolic if no symbolic is turned on
-        depth = len(self.symbolic_fun)
-        no_symbolic = True
-        for l in range(depth):
-            no_symbolic *= torch.sum(torch.abs(self.symbolic_fun[l].mask)) == 0
-
-        old_symbolic_enabled = self.symbolic_enabled
-
-        if no_symbolic:
-            self.symbolic_enabled = False
-
-        return old_save_act, old_symbolic_enabled
-
-    def get_params(self):
-        '''
-        Get parameters
-        '''
-        return self.parameters()
-
-
-    def fit(self, dataset, opt="LBFGS", steps=100, log=1, lamb=0., lamb_l1=1., lamb_entropy=2., lamb_coef=0., lamb_coefdiff=0., update_grid=True, grid_update_num=10, loss_fn=None, lr=1.,start_grid_update_step=-1, stop_grid_update_step=50, batch=-1,
-              metrics=None, save_fig=False, in_vars=None, out_vars=None, beta=3, save_fig_freq=1, img_folder='./video', singularity_avoiding=False, y_th=1000., reg_metric='edge_forward_spline_n', display_metrics=None):
-        '''
-        training
-
-        Args:
-        -----
-        dataset : dic
-            contains dataset['train_input'], dataset['train_label'], dataset['test_input'], dataset['test_label']
-        opt : str
-            "LBFGS" or "Adam"
-        steps : int
-            training steps
-        log : int
-            logging frequency
-        lamb : float
-            overall penalty strength
-        lamb_l1 : float
-            l1 penalty strength
-        lamb_entropy : float
-            entropy penalty strength
-        lamb_coef : float
-            coefficient magnitude penalty strength
-        lamb_coefdiff : float
-            difference of nearby coefficients (smoothness) penalty strength
-        update_grid : bool
-            If True, update grid regularly before stop_grid_update_step
-        grid_update_num : int
-            the number of grid updates before stop_grid_update_step
-        start_grid_update_step : int
-            no grid updates before this training step
-        stop_grid_update_step : int
-            no grid updates after this training step
-        loss_fn : function
-            loss function
-        lr : float
-            learning rate
-        batch : int
-            batch size, if -1 then full.
-        save_fig_freq : int
-            save figure every (save_fig_freq) steps
-        singularity_avoiding : bool
-            indicate whether to avoid singularity for the symbolic part
-        y_th : float
-            singularity threshold (anything above the threshold is considered singular and is softened in some ways)
-        reg_metric : str
-            regularization metric. Choose from {'edge_forward_spline_n', 'edge_forward_spline_u', 'edge_forward_sum', 'edge_backward', 'node_backward'}
-        metrics : a list of metrics (as functions)
-            the metrics to be computed in training
-        display_metrics : a list of functions
-            the metrics to be displayed in the tqdm progress bar
-
-        Returns:
-        --------
-        results : dic
-            results['train_loss'], 1D array of training losses (RMSE)
-            results['test_loss'], 1D array of test losses (RMSE)
-            results['reg'], 1D array of regularization
-            other metrics specified in metrics
-
-        Example
-        -------
-        >>> from kan import *
-        >>> model = KAN(width=[2,5,1], grid=5, k=3, noise_scale=0.3, seed=2)
-        >>> f = lambda x: torch.exp(torch.sin(torch.pi*x[:,[0]]) + x[:,[1]]**2)
-        >>> dataset = create_dataset(f, n_var=2)
-        >>> model.fit(dataset, opt='LBFGS', steps=20, lamb=0.001);
-        >>> model.plot()
-        # Most examples in the tutorials involve the fit() method. Please check them for usefulness.
-        '''
-
-        if lamb > 0. and not self.save_act:
-            print('setting lamb=0.
If you want to set lamb > 0, set self.save_act=True') - - old_save_act, old_symbolic_enabled = self.disable_symbolic_in_fit(lamb) - - pbar = tqdm(range(steps), desc='description', ncols=100) - - if loss_fn == None: - loss_fn = loss_fn_eval = lambda x, y: torch.mean((x - y) ** 2) - else: - loss_fn = loss_fn_eval = loss_fn - - grid_update_freq = int(stop_grid_update_step / grid_update_num) - - if opt == "Adam": - optimizer = torch.optim.Adam(self.get_params(), lr=lr) - elif opt == "LBFGS": - optimizer = LBFGS(self.get_params(), lr=lr, history_size=10, line_search_fn="strong_wolfe", tolerance_grad=1e-32, tolerance_change=1e-32, tolerance_ys=1e-32) - - results = {} - results['train_loss'] = [] - results['test_loss'] = [] - results['reg'] = [] - if metrics != None: - for i in range(len(metrics)): - results[metrics[i].__name__] = [] - - if batch == -1 or batch > dataset['train_input'].shape[0]: - batch_size = dataset['train_input'].shape[0] - batch_size_test = dataset['test_input'].shape[0] - else: - batch_size = batch - batch_size_test = batch - - global train_loss, reg_ - - def closure(): - global train_loss, reg_ - optimizer.zero_grad() - pred = self.forward(dataset['train_input'][train_id], singularity_avoiding=singularity_avoiding, y_th=y_th) - train_loss = loss_fn(pred, dataset['train_label'][train_id]) - if self.save_act: - if reg_metric == 'edge_backward': - self.attribute() - if reg_metric == 'node_backward': - self.node_attribute() - reg_ = self.get_reg(reg_metric, lamb_l1, lamb_entropy, lamb_coef, lamb_coefdiff) - else: - reg_ = torch.tensor(0.) - objective = train_loss + lamb * reg_ - objective.backward() - return objective - - if save_fig: - if not os.path.exists(img_folder): - os.makedirs(img_folder) - - for _ in pbar: - - if _ == steps-1 and old_save_act: - self.save_act = True - - if save_fig and _ % save_fig_freq == 0: - save_act = self.save_act - self.save_act = True - - train_id = np.random.choice(dataset['train_input'].shape[0], batch_size, replace=False) - test_id = np.random.choice(dataset['test_input'].shape[0], batch_size_test, replace=False) - - if _ % grid_update_freq == 0 and _ < stop_grid_update_step and update_grid and _ >= start_grid_update_step: - self.update_grid(dataset['train_input'][train_id]) - - if opt == "LBFGS": - optimizer.step(closure) - - if opt == "Adam": - pred = self.forward(dataset['train_input'][train_id], singularity_avoiding=singularity_avoiding, y_th=y_th) - train_loss = loss_fn(pred, dataset['train_label'][train_id]) - if self.save_act: - if reg_metric == 'edge_backward': - self.attribute() - if reg_metric == 'node_backward': - self.node_attribute() - reg_ = self.get_reg(reg_metric, lamb_l1, lamb_entropy, lamb_coef, lamb_coefdiff) - else: - reg_ = torch.tensor(0.) 
- loss = train_loss + lamb * reg_ - optimizer.zero_grad() - loss.backward() - optimizer.step() - - test_loss = loss_fn_eval(self.forward(dataset['test_input'][test_id]), dataset['test_label'][test_id]) - - - if metrics != None: - for i in range(len(metrics)): - results[metrics[i].__name__].append(metrics[i]().item()) - - results['train_loss'].append(torch.sqrt(train_loss).cpu().detach().numpy()) - results['test_loss'].append(torch.sqrt(test_loss).cpu().detach().numpy()) - results['reg'].append(reg_.cpu().detach().numpy()) - - if _ % log == 0: - if display_metrics == None: - pbar.set_description("| train_loss: %.2e | test_loss: %.2e | reg: %.2e | " % (torch.sqrt(train_loss).cpu().detach().numpy(), torch.sqrt(test_loss).cpu().detach().numpy(), reg_.cpu().detach().numpy())) - else: - string = '' - data = () - for metric in display_metrics: - string += f' {metric}: %.2e |' - try: - results[metric] - except: - raise Exception(f'{metric} not recognized') - data += (results[metric][-1],) - pbar.set_description(string % data) - - - if save_fig and _ % save_fig_freq == 0: - self.plot(folder=img_folder, in_vars=in_vars, out_vars=out_vars, title="Step {}".format(_), beta=beta) - plt.savefig(img_folder + '/' + str(_) + '.jpg', bbox_inches='tight', dpi=200) - plt.close() - self.save_act = save_act - - self.log_history('fit') - # revert back to original state - self.symbolic_enabled = old_symbolic_enabled - return results - - def prune_node(self, threshold=1e-2, mode="auto", active_neurons_id=None, log_history=True): - ''' - pruning nodes - - Args: - ----- - threshold : float - if the attribution score of a neuron is below the threshold, it is considered dead and will be removed - mode : str - 'auto' or 'manual'. with 'auto', nodes are automatically pruned using threshold. with 'manual', active_neurons_id should be passed in. 
- - Returns: - -------- - pruned network : MultKAN - - Example - ------- - >>> from kan import * - >>> model = KAN(width=[2,5,1], grid=5, k=3, noise_scale=0.3, seed=2) - >>> f = lambda x: torch.exp(torch.sin(torch.pi*x[:,[0]]) + x[:,[1]]**2) - >>> dataset = create_dataset(f, n_var=2) - >>> model.fit(dataset, opt='LBFGS', steps=20, lamb=0.001); - >>> model = model.prune_node() - >>> model.plot() - ''' - if self.acts == None: - self.get_act() - - mask_up = [torch.ones(self.width_in[0], device=self.device)] - mask_down = [] - active_neurons_up = [list(range(self.width_in[0]))] - active_neurons_down = [] - num_sums = [] - num_mults = [] - mult_arities = [[]] - - if active_neurons_id != None: - mode = "manual" - - for i in range(len(self.acts_scale) - 1): - - mult_arity = [] - - if mode == "auto": - self.attribute() - overall_important_up = self.node_scores[i+1] > threshold - - elif mode == "manual": - overall_important_up = torch.zeros(self.width_in[i + 1], dtype=torch.bool, device=self.device) - overall_important_up[active_neurons_id[i]] = True - - - num_sum = torch.sum(overall_important_up[:self.width[i+1][0]]) - num_mult = torch.sum(overall_important_up[self.width[i+1][0]:]) - if self.mult_homo == True: - overall_important_down = torch.cat([overall_important_up[:self.width[i+1][0]], (overall_important_up[self.width[i+1][0]:][None,:].expand(self.mult_arity,-1)).T.reshape(-1,)], dim=0) - else: - overall_important_down = overall_important_up[:self.width[i+1][0]] - for j in range(overall_important_up[self.width[i+1][0]:].shape[0]): - active_bool = overall_important_up[self.width[i+1][0]+j] - arity = self.mult_arity[i+1][j] - overall_important_down = torch.cat([overall_important_down, torch.tensor([active_bool]*arity).to(self.device)]) - if active_bool: - mult_arity.append(arity) - - num_sums.append(num_sum.item()) - num_mults.append(num_mult.item()) - - mask_up.append(overall_important_up.float()) - mask_down.append(overall_important_down.float()) - - active_neurons_up.append(torch.where(overall_important_up == True)[0]) - active_neurons_down.append(torch.where(overall_important_down == True)[0]) - - mult_arities.append(mult_arity) - - active_neurons_down.append(list(range(self.width_out[-1]))) - mask_down.append(torch.ones(self.width_out[-1], device=self.device)) - - if self.mult_homo == False: - mult_arities.append(self.mult_arity[-1]) - - self.mask_up = mask_up - self.mask_down = mask_down - - # update act_fun[l].mask up - for l in range(len(self.acts_scale) - 1): - for i in range(self.width_in[l + 1]): - if i not in active_neurons_up[l + 1]: - self.remove_node(l + 1, i, mode='up',log_history=False) - - for i in range(self.width_out[l + 1]): - if i not in active_neurons_down[l]: - self.remove_node(l + 1, i, mode='down',log_history=False) - - model2 = MultKAN(copy.deepcopy(self.width), grid=self.grid, k=self.k, base_fun=self.base_fun_name, mult_arity=self.mult_arity, ckpt_path=self.ckpt_path, auto_save=True, first_init=False, state_id=self.state_id, round=self.round).to(self.device) - model2.load_state_dict(self.state_dict()) - - width_new = [self.width[0]] - - for i in range(len(self.acts_scale)): - - if i < len(self.acts_scale) - 1: - num_sum = num_sums[i] - num_mult = num_mults[i] - model2.node_bias[i].data = model2.node_bias[i].data[active_neurons_up[i+1]] - model2.node_scale[i].data = model2.node_scale[i].data[active_neurons_up[i+1]] - model2.subnode_bias[i].data = model2.subnode_bias[i].data[active_neurons_down[i]] - model2.subnode_scale[i].data = 
model2.subnode_scale[i].data[active_neurons_down[i]]
-                model2.width[i+1] = [num_sum, num_mult]
-
-                model2.act_fun[i].out_dim_sum = num_sum
-                model2.act_fun[i].out_dim_mult = num_mult
-
-                model2.symbolic_fun[i].out_dim_sum = num_sum
-                model2.symbolic_fun[i].out_dim_mult = num_mult
-
-                width_new.append([num_sum, num_mult])
-
-            model2.act_fun[i] = model2.act_fun[i].get_subset(active_neurons_up[i], active_neurons_down[i])
-            model2.symbolic_fun[i] = self.symbolic_fun[i].get_subset(active_neurons_up[i], active_neurons_down[i])
-
-        model2.cache_data = self.cache_data
-        model2.acts = None
-
-        width_new.append(self.width[-1])
-        model2.width = width_new
-
-        if self.mult_homo == False:
-            model2.mult_arity = mult_arities
-
-        if log_history:
-            self.log_history('prune_node')
-            model2.state_id += 1
-
-        return model2
-
-    def prune_edge(self, threshold=3e-2, log_history=True):
-        '''
-        pruning edges
-
-        Args:
-        -----
-        threshold : float
-            if the attribution score of an edge is below the threshold, it is considered dead and will be set to zero.
-
-        Returns:
-        --------
-        pruned network : MultKAN
-
-        Example
-        -------
-        >>> from kan import *
-        >>> model = KAN(width=[2,5,1], grid=5, k=3, noise_scale=0.3, seed=2)
-        >>> f = lambda x: torch.exp(torch.sin(torch.pi*x[:,[0]]) + x[:,[1]]**2)
-        >>> dataset = create_dataset(f, n_var=2)
-        >>> model.fit(dataset, opt='LBFGS', steps=20, lamb=0.001);
-        >>> model = model.prune_edge()
-        >>> model.plot()
-        '''
-        if self.acts == None:
-            self.get_act()
-
-        for i in range(len(self.width)-1):
-            #self.act_fun[i].mask.data = ((self.acts_scale[i] > threshold).permute(1,0)).float()
-            old_mask = self.act_fun[i].mask.data
-            self.act_fun[i].mask.data = ((self.edge_scores[i] > threshold).permute(1,0)*old_mask).float()
-
-        if log_history:
-            self.log_history('prune_edge')
-
-    def prune(self, node_th=1e-2, edge_th=3e-2):
-        '''
-        prune (both nodes and edges)
-
-        Args:
-        -----
-        node_th : float
-            if the attribution score of a node is below node_th, it is considered dead and will be set to zero.
-        edge_th : float
-            if the attribution score of an edge is below edge_th, it is considered dead and will be set to zero.
-
-        Returns:
-        --------
-        pruned network : MultKAN
-
-        Example
-        -------
-        >>> from kan import *
-        >>> model = KAN(width=[2,5,1], grid=5, k=3, noise_scale=0.3, seed=2)
-        >>> f = lambda x: torch.exp(torch.sin(torch.pi*x[:,[0]]) + x[:,[1]]**2)
-        >>> dataset = create_dataset(f, n_var=2)
-        >>> model.fit(dataset, opt='LBFGS', steps=20, lamb=0.001);
-        >>> model = model.prune()
-        >>> model.plot()
-        '''
-        if self.acts == None:
-            self.get_act()
-
-        self = self.prune_node(node_th, log_history=False)
-        #self.prune_node(node_th, log_history=False)
-        self.forward(self.cache_data)
-        self.attribute()
-        self.prune_edge(edge_th, log_history=False)
-        self.log_history('prune')
-        return self
-
-    def prune_input(self, threshold=1e-2, active_inputs=None, log_history=True):
-        '''
-        prune inputs
-
-        Args:
-        -----
-        threshold : float
-            if the attribution score of the input feature is below threshold, it is considered irrelevant.
-        active_inputs : None or list
-            if a list is passed, the manual mode will disregard attribution score and prune as instructed.
-
-        Returns:
-        --------
-        pruned network : MultKAN
-
-        Example1
-        --------
-        >>> # automatic
-        >>> from kan import *
-        >>> model = KAN(width=[3,5,1], grid=5, k=3, noise_scale=0.3, seed=2)
-        >>> f = lambda x: 1 * x[:,[0]]**2 + 0.3 * x[:,[1]]**2 + 0.0 * x[:,[2]]**2
-        >>> dataset = create_dataset(f, n_var=3)
-        >>> model.fit(dataset, opt='LBFGS', steps=20, lamb=0.001);
-        >>> model.plot()
-        >>> model = model.prune_input()
-        >>> model.plot()
-
-        Example2
-        --------
-        >>> # manual
-        >>> from kan import *
-        >>> model = KAN(width=[3,5,1], grid=5, k=3, noise_scale=0.3, seed=2)
-        >>> f = lambda x: 1 * x[:,[0]]**2 + 0.3 * x[:,[1]]**2 + 0.0 * x[:,[2]]**2
-        >>> dataset = create_dataset(f, n_var=3)
-        >>> model.fit(dataset, opt='LBFGS', steps=20, lamb=0.001);
-        >>> model.plot()
-        >>> model = model.prune_input(active_inputs=[0,1])
-        >>> model.plot()
-        '''
-        if active_inputs == None:
-            self.attribute()
-            input_score = self.node_scores[0]
-            input_mask = input_score > threshold
-            print('keep:', input_mask.tolist())
-            input_id = torch.where(input_mask==True)[0]
-
-        else:
-            input_id = torch.tensor(active_inputs, dtype=torch.long).to(self.device)
-
-        model2 = MultKAN(copy.deepcopy(self.width), grid=self.grid, k=self.k, base_fun=self.base_fun, mult_arity=self.mult_arity, ckpt_path=self.ckpt_path, auto_save=True, first_init=False, state_id=self.state_id, round=self.round).to(self.device)
-        model2.load_state_dict(self.state_dict())
-
-        model2.act_fun[0] = model2.act_fun[0].get_subset(input_id, torch.arange(self.width_out[1]))
-        model2.symbolic_fun[0] = self.symbolic_fun[0].get_subset(input_id, torch.arange(self.width_out[1]))
-
-        model2.cache_data = self.cache_data
-        model2.acts = None
-
-        model2.width[0] = [len(input_id), 0]
-        model2.input_id = input_id
-
-        if log_history:
-            self.log_history('prune_input')
-            model2.state_id += 1
-
-        return model2
-
-    def remove_edge(self, l, i, j, log_history=True):
-        '''
-        remove activation phi(l,i,j) (set its mask to zero)
-        '''
-        self.act_fun[l].mask[i][j] = 0.
-        if log_history:
-            self.log_history('remove_edge')
-
-    def remove_node(self, l, i, mode='all', log_history=True):
-        '''
-        remove neuron (l,i) (set the masks of all incoming and outgoing activation functions to zero)
-        '''
-        if mode == 'down':
-            self.act_fun[l - 1].mask[:, i] = 0.
-            self.symbolic_fun[l - 1].mask[i, :] *= 0.
-
-        elif mode == 'up':
-            self.act_fun[l].mask[i, :] = 0.
-            self.symbolic_fun[l].mask[:, i] *= 0.
- - else: - self.remove_node(l, i, mode='up') - self.remove_node(l, i, mode='down') - - if log_history: - self.log_history('remove_node') - - - def attribute(self, l=None, i=None, out_score=None, plot=True): - ''' - get attribution scores - - Args: - ----- - l : None or int - layer index - i : None or int - neuron index - out_score : None or 1D torch.float - specify output scores - plot : bool - when plot = True, display the bar show - - Returns: - -------- - attribution scores - - Example - ------- - >>> from kan import * - >>> model = KAN(width=[3,5,1], grid=5, k=3, noise_scale=0.3, seed=2) - >>> f = lambda x: 1 * x[:,[0]]**2 + 0.3 * x[:,[1]]**2 + 0.0 * x[:,[2]]**2 - >>> dataset = create_dataset(f, n_var=3) - >>> model.fit(dataset, opt='LBFGS', steps=20, lamb=0.001); - >>> model.attribute() - >>> model.feature_score - ''' - # output (out_dim, in_dim) - - if l != None: - self.attribute() - out_score = self.node_scores[l] - - if self.acts == None: - self.get_act() - - def score_node2subnode(node_score, width, mult_arity, out_dim): - - assert np.sum(width) == node_score.shape[1] - if isinstance(mult_arity, int): - n_subnode = width[0] + mult_arity * width[1] - else: - n_subnode = width[0] + int(np.sum(mult_arity)) - - #subnode_score_leaf = torch.zeros(out_dim, n_subnode).requires_grad_(True) - #subnode_score = subnode_score_leaf.clone() - #subnode_score[:,:width[0]] = node_score[:,:width[0]] - subnode_score = node_score[:,:width[0]] - if isinstance(mult_arity, int): - #subnode_score[:,width[0]:] = node_score[:,width[0]:][:,:,None].expand(out_dim, node_score[width[0]:].shape[0], mult_arity).reshape(out_dim,-1) - subnode_score = torch.cat([subnode_score, node_score[:,width[0]:][:,:,None].expand(out_dim, node_score[:,width[0]:].shape[1], mult_arity).reshape(out_dim,-1)], dim=1) - else: - acml = width[0] - for i in range(len(mult_arity)): - #subnode_score[:, acml:acml+mult_arity[i]] = node_score[:, width[0]+i] - subnode_score = torch.cat([subnode_score, node_score[:, width[0]+i].expand(out_dim, mult_arity[i])], dim=1) - acml += mult_arity[i] - return subnode_score - - - node_scores = [] - subnode_scores = [] - edge_scores = [] - - l_query = l - if l == None: - l_end = self.depth - else: - l_end = l - - # back propagate from the queried layer - out_dim = self.width_in[l_end] - if out_score == None: - node_score = torch.eye(out_dim).requires_grad_(True) - else: - node_score = torch.diag(out_score).requires_grad_(True) - node_scores.append(node_score) - - device = self.act_fun[0].grid.device - - for l in range(l_end,0,-1): - - # node to subnode - if isinstance(self.mult_arity, int): - subnode_score = score_node2subnode(node_score, self.width[l], self.mult_arity, out_dim=out_dim) - else: - mult_arity = self.mult_arity[l] - #subnode_score = score_node2subnode(node_score, self.width[l], mult_arity) - subnode_score = score_node2subnode(node_score, self.width[l], mult_arity, out_dim=out_dim) - - subnode_scores.append(subnode_score) - # subnode to edge - #print(self.edge_actscale[l-1].device, subnode_score.device, self.subnode_actscale[l-1].device) - edge_score = torch.einsum('ij,ki,i->kij', self.edge_actscale[l-1], subnode_score.to(device), 1/(self.subnode_actscale[l-1]+1e-4)) - edge_scores.append(edge_score) - - # edge to node - node_score = torch.sum(edge_score, dim=1) - node_scores.append(node_score) - - self.node_scores_all = list(reversed(node_scores)) - self.edge_scores_all = list(reversed(edge_scores)) - self.subnode_scores_all = list(reversed(subnode_scores)) - - self.node_scores = 
[torch.mean(l, dim=0) for l in self.node_scores_all]
-        self.edge_scores = [torch.mean(l, dim=0) for l in self.edge_scores_all]
-        self.subnode_scores = [torch.mean(l, dim=0) for l in self.subnode_scores_all]
-
-        # return
-        if l_query != None:
-            if i == None:
-                return self.node_scores_all[0]
-            else:
-
-                # plot
-                if plot:
-                    in_dim = self.width_in[0]
-                    plt.figure(figsize=(1*in_dim, 3))
-                    plt.bar(range(in_dim),self.node_scores_all[0][i].cpu().detach().numpy())
-                    plt.xticks(range(in_dim));
-
-                return self.node_scores_all[0][i]
-
-    def node_attribute(self):
-        self.node_attribute_scores = []
-        for l in range(1, self.depth+1):
-            node_attr = self.attribute(l)
-            self.node_attribute_scores.append(node_attr)
-
-    def feature_interaction(self, l, neuron_th = 1e-2, feature_th = 1e-2):
-        '''
-        get feature interaction
-
-        Args:
-        -----
-        l : int
-            layer index
-        neuron_th : float
-            threshold to determine whether a neuron is active
-        feature_th : float
-            threshold to determine whether a feature is active
-
-        Returns:
-        --------
-        dictionary
-
-        Example
-        -------
-        >>> from kan import *
-        >>> model = KAN(width=[3,5,1], grid=5, k=3, noise_scale=0.3, seed=2)
-        >>> f = lambda x: 1 * x[:,[0]]**2 + 0.3 * x[:,[1]]**2 + 0.0 * x[:,[2]]**2
-        >>> dataset = create_dataset(f, n_var=3)
-        >>> model.fit(dataset, opt='LBFGS', steps=20, lamb=0.001);
-        >>> model.attribute()
-        >>> model.feature_interaction(1)
-        '''
-        dic = {}
-        width = self.width_in[l]
-
-        for i in range(width):
-            score = self.attribute(l,i,plot=False)
-
-            if torch.max(score) > neuron_th:
-                features = tuple(torch.where(score > torch.max(score) * feature_th)[0].detach().numpy())
-                if features in dic.keys():
-                    dic[features] += 1
-                else:
-                    dic[features] = 1
-
-        return dic
-
-    def suggest_symbolic(self, l, i, j, a_range=(-10, 10), b_range=(-10, 10), lib=None, topk=5, verbose=True, r2_loss_fun=lambda x: np.log2(1+1e-5-x), c_loss_fun=lambda x: x, weight_simple = 0.8):
-        '''
-        suggest symbolic function
-
-        Args:
-        -----
-        l : int
-            layer index
-        i : int
-            neuron index in layer l
-        j : int
-            neuron index in layer l+1
-        a_range : tuple
-            search range of a
-        b_range : tuple
-            search range of b
-        lib : list of str
-            library of candidate symbolic functions
-        topk : int
-            the number of top functions displayed
-        verbose : bool
-            if verbose = True, print more information
-        r2_loss_fun : function
-            function : r2 -> "bits"
-        c_loss_fun : function
-            function : c -> 'bits'
-        weight_simple : float
-            the simplicity weight: the higher, the more simplicity is preferred over performance
-
-
-        Returns:
-        --------
-        best_name (str), best_fun (function), best_r2 (float), best_c (float)
-
-        Example
-        -------
-        >>> from kan import *
-        >>> model = KAN(width=[2,1,1], grid=5, k=3, noise_scale=0.0, seed=0)
-        >>> f = lambda x: torch.exp(torch.sin(torch.pi*x[:,[0]])+x[:,[1]]**2)
-        >>> dataset = create_dataset(f, n_var=3)
-        >>> model.fit(dataset, opt='LBFGS', steps=20, lamb=0.001);
-        >>> model.suggest_symbolic(0,1,0)
-        '''
-        r2s = []
-        cs = []
-
-        if lib == None:
-            symbolic_lib = SYMBOLIC_LIB
-        else:
-            symbolic_lib = {}
-            for item in lib:
-                symbolic_lib[item] = SYMBOLIC_LIB[item]
-
-        # getting r2 and complexities
-        for (name, content) in symbolic_lib.items():
-            r2 = self.fix_symbolic(l, i, j, name, a_range=a_range, b_range=b_range, verbose=False, log_history=False)
-            if r2 == -1e8: # zero function
-                r2s.append(-1e8)
-            else:
-                r2s.append(r2.item())
-                self.unfix_symbolic(l, i, j, log_history=False)
-            c = content[2]
-            cs.append(c)
-
-        r2s = np.array(r2s)
-        cs = np.array(cs)
-        r2_loss = r2_loss_fun(r2s).astype('float')
-        cs_loss = c_loss_fun(cs)
-
-        loss = weight_simple * cs_loss + (1-weight_simple) * r2_loss
-
-        sorted_ids = np.argsort(loss)[:topk]
-        r2s = r2s[sorted_ids][:topk]
-        cs = cs[sorted_ids][:topk]
-        r2_loss = r2_loss[sorted_ids][:topk]
-        cs_loss = cs_loss[sorted_ids][:topk]
-        loss = loss[sorted_ids][:topk]
-
-        topk = np.minimum(topk, len(symbolic_lib))
-
-        if verbose == True:
-            # print results in a dataframe
-            results = {}
-            results['function'] = [list(symbolic_lib.items())[sorted_ids[i]][0] for i in range(topk)]
-            results['fitting r2'] = r2s[:topk]
-            results['r2 loss'] = r2_loss[:topk]
-            results['complexity'] = cs[:topk]
-            results['complexity loss'] = cs_loss[:topk]
-            results['total loss'] = loss[:topk]
-
-            df = pd.DataFrame(results)
-            print(df)
-
-        best_name = list(symbolic_lib.items())[sorted_ids[0]][0]
-        best_fun = list(symbolic_lib.items())[sorted_ids[0]][1]
-        best_r2 = r2s[0]
-        best_c = cs[0]
-
-        return best_name, best_fun, best_r2, best_c
-
-    def auto_symbolic(self, a_range=(-10, 10), b_range=(-10, 10), lib=None, verbose=1, weight_simple = 0.8, r2_threshold=0.0):
-        '''
-        automatic symbolic regression for all edges
-
-        Args:
-        -----
-        a_range : tuple
-            search range of a
-        b_range : tuple
-            search range of b
-        lib : list of str
-            library of candidate symbolic functions
-        verbose : int
-            larger values print more information
-        weight_simple : float
-            a weight that prioritizes simplicity (low complexity) over performance (high r2)
-            set to 0.0 to ignore complexity
-        r2_threshold : float
-            If r2 is below this threshold, the edge will not be fixed with any symbolic function
-            set to 0.0 to ignore this threshold
-        Returns:
-        --------
-        None
-
-        Example
-        -------
-        >>> from kan import *
-        >>> model = KAN(width=[2,1,1], grid=5, k=3, noise_scale=0.0, seed=0)
-        >>> f = lambda x: torch.exp(torch.sin(torch.pi*x[:,[0]])+x[:,[1]]**2)
-        >>> dataset = create_dataset(f, n_var=3)
-        >>> model.fit(dataset, opt='LBFGS', steps=20, lamb=0.001);
-        >>> model.auto_symbolic()
-        '''
-        for l in range(len(self.width_in) - 1):
-            for i in range(self.width_in[l]):
-                for j in range(self.width_out[l + 1]):
-                    if self.symbolic_fun[l].mask[j, i] > 0. and self.act_fun[l].mask[i][j] == 0.:
-                        print(f'skipping ({l},{i},{j}) since already symbolic')
-                    elif self.symbolic_fun[l].mask[j, i] == 0. and self.act_fun[l].mask[i][j] == 0.:
-                        self.fix_symbolic(l, i, j, '0', verbose=verbose > 1, log_history=False)
-                        print(f'fixing ({l},{i},{j}) with 0')
-                    else:
-                        name, fun, r2, c = self.suggest_symbolic(l, i, j, a_range=a_range, b_range=b_range, lib=lib, verbose=False, weight_simple=weight_simple)
-                        if r2 >= r2_threshold:
-                            self.fix_symbolic(l, i, j, name, verbose=verbose > 1, log_history=False)
-                            if verbose >= 1:
-                                print(f'fixing ({l},{i},{j}) with {name}, r2={r2}, c={c}')
-                        else:
-                            print(f'For ({l},{i},{j}) the best fit was {name}, but r^2 = {r2} and this is lower than {r2_threshold}.
This edge was omitted, keep training or try a different threshold.') - - self.log_history('auto_symbolic') - - def symbolic_formula(self, var=None, normalizer=None, output_normalizer = None): - ''' - get symbolic formula - - Args: - ----- - var : None or a list of sympy expression - input variables - normalizer : [mean, std] - output_normalizer : [mean, std] - - Returns: - -------- - None - - Example - ------- - >>> from kan import * - >>> model = KAN(width=[2,1,1], grid=5, k=3, noise_scale=0.0, seed=0) - >>> f = lambda x: torch.exp(torch.sin(torch.pi*x[:,[0]])+x[:,[1]]**2) - >>> dataset = create_dataset(f, n_var=3) - >>> model.fit(dataset, opt='LBFGS', steps=20, lamb=0.001); - >>> model.auto_symbolic() - >>> model.symbolic_formula()[0][0] - ''' - - symbolic_acts = [] - symbolic_acts_premult = [] - x = [] - - def ex_round(ex1, n_digit): - ex2 = ex1 - for a in sympy.preorder_traversal(ex1): - if isinstance(a, sympy.Float): - ex2 = ex2.subs(a, round(a, n_digit)) - return ex2 - - # define variables - if var == None: - for ii in range(1, self.width[0][0] + 1): - exec(f"x{ii} = sympy.Symbol('x_{ii}')") - exec(f"x.append(x{ii})") - elif isinstance(var[0], sympy.Expr): - x = var - else: - x = [sympy.symbols(var_) for var_ in var] - - x0 = x - - if normalizer != None: - mean = normalizer[0] - std = normalizer[1] - x = [(x[i] - mean[i]) / std[i] for i in range(len(x))] - - symbolic_acts.append(x) - - for l in range(len(self.width_in) - 1): - num_sum = self.width[l + 1][0] - num_mult = self.width[l + 1][1] - y = [] - for j in range(self.width_out[l + 1]): - yj = 0. - for i in range(self.width_in[l]): - a, b, c, d = self.symbolic_fun[l].affine[j, i] - sympy_fun = self.symbolic_fun[l].funs_sympy[j][i] - try: - yj += c * sympy_fun(a * x[i] + b) + d - except: - print('make sure all activations need to be converted to symbolic formulas first!') - return - yj = self.subnode_scale[l][j] * yj + self.subnode_bias[l][j] - if simplify == True: - y.append(sympy.simplify(yj)) - else: - y.append(yj) - - symbolic_acts_premult.append(y) - - mult = [] - for k in range(num_mult): - if isinstance(self.mult_arity, int): - mult_arity = self.mult_arity - else: - mult_arity = self.mult_arity[l+1][k] - for i in range(mult_arity-1): - if i == 0: - mult_k = y[num_sum+2*k] * y[num_sum+2*k+1] - else: - mult_k = mult_k * y[num_sum+2*k+i+1] - mult.append(mult_k) - - y = y[:num_sum] + mult - - for j in range(self.width_in[l+1]): - y[j] = self.node_scale[l][j] * y[j] + self.node_bias[l][j] - - x = y - symbolic_acts.append(x) - - if output_normalizer != None: - output_layer = symbolic_acts[-1] - means = output_normalizer[0] - stds = output_normalizer[1] - - assert len(output_layer) == len(means), 'output_normalizer does not match the output layer' - assert len(output_layer) == len(stds), 'output_normalizer does not match the output layer' - - output_layer = [(output_layer[i] * stds[i] + means[i]) for i in range(len(output_layer))] - symbolic_acts[-1] = output_layer - - - self.symbolic_acts = [[symbolic_acts[l][i] for i in range(len(symbolic_acts[l]))] for l in range(len(symbolic_acts))] - self.symbolic_acts_premult = [[symbolic_acts_premult[l][i] for i in range(len(symbolic_acts_premult[l]))] for l in range(len(symbolic_acts_premult))] - - out_dim = len(symbolic_acts[-1]) - #return [symbolic_acts[-1][i] for i in range(len(symbolic_acts[-1]))], x0 - - if simplify: - return [symbolic_acts[-1][i] for i in range(len(symbolic_acts[-1]))], x0 - else: - return [symbolic_acts[-1][i] for i in range(len(symbolic_acts[-1]))], x0 - - - def 
expand_depth(self):
-        '''
-        expand network depth, add an identity layer to the end. For usage, please refer to tutorials interp_3_KAN_compiler.ipynb.
-
-        Args:
-        -----
-        None
-
-        Returns:
-        --------
-        None
-        '''
-        self.depth += 1
-
-        # add kanlayer, set mask to zero
-        dim_out = self.width_in[-1]
-        layer = KANLayer(dim_out, dim_out, num=self.grid, k=self.k)
-        layer.mask *= 0.
-        self.act_fun.append(layer)
-
-        self.width.append([dim_out, 0])
-        self.mult_arity.append([])
-
-        # add symbolic_kanlayer set mask to one. fun = identity on diagonal and zero for off-diagonal
-        layer = Symbolic_KANLayer(dim_out, dim_out)
-        layer.mask += 1.
-
-        for j in range(dim_out):
-            for i in range(dim_out):
-                if i == j:
-                    layer.fix_symbolic(i,j,'x')
-                else:
-                    layer.fix_symbolic(i,j,'0')
-
-        self.symbolic_fun.append(layer)
-
-        self.node_bias.append(torch.nn.Parameter(torch.zeros(dim_out,device=self.device)).requires_grad_(self.affine_trainable))
-        self.node_scale.append(torch.nn.Parameter(torch.ones(dim_out,device=self.device)).requires_grad_(self.affine_trainable))
-        self.subnode_bias.append(torch.nn.Parameter(torch.zeros(dim_out,device=self.device)).requires_grad_(self.affine_trainable))
-        self.subnode_scale.append(torch.nn.Parameter(torch.ones(dim_out,device=self.device)).requires_grad_(self.affine_trainable))
-
-    def expand_width(self, layer_id, n_added_nodes, sum_bool=True, mult_arity=2):
-        '''
-        expand network width. For usage, please refer to tutorials interp_3_KAN_compiler.ipynb.
-
-        Args:
-        -----
-        layer_id : int
-            layer index
-        n_added_nodes : int
-            the number of added nodes
-        sum_bool : bool
-            if sum_bool == True, added nodes are addition nodes; otherwise multiplication nodes
-        mult_arity : int
-            multiplication arity (the number of numbers to be multiplied)
-
-        Returns:
-        --------
-        None
-        '''
-        def _expand(layer_id, n_added_nodes, sum_bool=True, mult_arity=2, added_dim='out'):
-            l = layer_id
-            in_dim = self.symbolic_fun[l].in_dim
-            out_dim = self.symbolic_fun[l].out_dim
-            if sum_bool:
-
-                if added_dim == 'out':
-                    new = Symbolic_KANLayer(in_dim, out_dim + n_added_nodes)
-                    old = self.symbolic_fun[l]
-                    in_id = np.arange(in_dim)
-                    out_id = np.arange(out_dim + n_added_nodes)
-
-                    for j in out_id:
-                        for i in in_id:
-                            new.fix_symbolic(i,j,'0')
-                    new.mask += 1.
-
-                    for j in out_id:
-                        for i in in_id:
-                            if j > n_added_nodes-1:
-                                new.funs[j][i] = old.funs[j-n_added_nodes][i]
-                                new.funs_avoid_singularity[j][i] = old.funs_avoid_singularity[j-n_added_nodes][i]
-                                new.funs_sympy[j][i] = old.funs_sympy[j-n_added_nodes][i]
-                                new.funs_name[j][i] = old.funs_name[j-n_added_nodes][i]
-                                new.affine.data[j][i] = old.affine.data[j-n_added_nodes][i]
-
-                    self.symbolic_fun[l] = new
-                    self.act_fun[l] = KANLayer(in_dim, out_dim + n_added_nodes, num=self.grid, k=self.k)
-                    self.act_fun[l].mask *= 0.
- - self.node_scale[l].data = torch.cat([torch.ones(n_added_nodes, device=self.device), self.node_scale[l].data]) - self.node_bias[l].data = torch.cat([torch.zeros(n_added_nodes, device=self.device), self.node_bias[l].data]) - self.subnode_scale[l].data = torch.cat([torch.ones(n_added_nodes, device=self.device), self.subnode_scale[l].data]) - self.subnode_bias[l].data = torch.cat([torch.zeros(n_added_nodes, device=self.device), self.subnode_bias[l].data]) - - - - if added_dim == 'in': - new = Symbolic_KANLayer(in_dim + n_added_nodes, out_dim) - old = self.symbolic_fun[l] - in_id = np.arange(in_dim + n_added_nodes) - out_id = np.arange(out_dim) - - for j in out_id: - for i in in_id: - new.fix_symbolic(i,j,'0') - new.mask += 1. - - for j in out_id: - for i in in_id: - if i > n_added_nodes-1: - new.funs[j][i] = old.funs[j][i-n_added_nodes] - new.funs_avoid_singularity[j][i] = old.funs_avoid_singularity[j][i-n_added_nodes] - new.funs_sympy[j][i] = old.funs_sympy[j][i-n_added_nodes] - new.funs_name[j][i] = old.funs_name[j][i-n_added_nodes] - new.affine.data[j][i] = old.affine.data[j][i-n_added_nodes] - - self.symbolic_fun[l] = new - self.act_fun[l] = KANLayer(in_dim + n_added_nodes, out_dim, num=self.grid, k=self.k) - self.act_fun[l].mask *= 0. - - - else: - - if isinstance(mult_arity, int): - mult_arity = [mult_arity] * n_added_nodes - - if added_dim == 'out': - n_added_subnodes = np.sum(mult_arity) - new = Symbolic_KANLayer(in_dim, out_dim + n_added_subnodes) - old = self.symbolic_fun[l] - in_id = np.arange(in_dim) - out_id = np.arange(out_dim + n_added_nodes) - - for j in out_id: - for i in in_id: - new.fix_symbolic(i,j,'0') - new.mask += 1. - - for j in out_id: - for i in in_id: - if j < out_dim: - new.funs[j][i] = old.funs[j][i] - new.funs_avoid_singularity[j][i] = old.funs_avoid_singularity[j][i] - new.funs_sympy[j][i] = old.funs_sympy[j][i] - new.funs_name[j][i] = old.funs_name[j][i] - new.affine.data[j][i] = old.affine.data[j][i] - - self.symbolic_fun[l] = new - self.act_fun[l] = KANLayer(in_dim, out_dim + n_added_subnodes, num=self.grid, k=self.k) - self.act_fun[l].mask *= 0. - - self.node_scale[l].data = torch.cat([self.node_scale[l].data, torch.ones(n_added_nodes, device=self.device)]) - self.node_bias[l].data = torch.cat([self.node_bias[l].data, torch.zeros(n_added_nodes, device=self.device)]) - self.subnode_scale[l].data = torch.cat([self.subnode_scale[l].data, torch.ones(n_added_subnodes, device=self.device)]) - self.subnode_bias[l].data = torch.cat([self.subnode_bias[l].data, torch.zeros(n_added_subnodes, device=self.device)]) - - if added_dim == 'in': - new = Symbolic_KANLayer(in_dim + n_added_nodes, out_dim) - old = self.symbolic_fun[l] - in_id = np.arange(in_dim + n_added_nodes) - out_id = np.arange(out_dim) - - for j in out_id: - for i in in_id: - new.fix_symbolic(i,j,'0') - new.mask += 1. - - for j in out_id: - for i in in_id: - if i < in_dim: - new.funs[j][i] = old.funs[j][i] - new.funs_avoid_singularity[j][i] = old.funs_avoid_singularity[j][i] - new.funs_sympy[j][i] = old.funs_sympy[j][i] - new.funs_name[j][i] = old.funs_name[j][i] - new.affine.data[j][i] = old.affine.data[j][i] - - self.symbolic_fun[l] = new - self.act_fun[l] = KANLayer(in_dim + n_added_nodes, out_dim, num=self.grid, k=self.k) - self.act_fun[l].mask *= 0. 
- - _expand(layer_id-1, n_added_nodes, sum_bool, mult_arity, added_dim='out') - _expand(layer_id, n_added_nodes, sum_bool, mult_arity, added_dim='in') - if sum_bool: - self.width[layer_id][0] += n_added_nodes - else: - if isinstance(mult_arity, int): - mult_arity = [mult_arity] * n_added_nodes - - self.width[layer_id][1] += n_added_nodes - self.mult_arity[layer_id] += mult_arity - - def perturb(self, mag=1.0, mode='non-intrusive'): - ''' - perturb a network. For usage, please refer to tutorials interp_3_KAN_compiler.ipynb. - - Args: - ----- - mag : float - perturbation magnitude - mode : str - perturbation mode, choices = {'non-intrusive', 'all', 'minimal'} - - Returns: - -------- - None - ''' - perturb_bool = {} - - if mode == 'all': - perturb_bool['aa_a'] = True - perturb_bool['aa_i'] = True - perturb_bool['ai'] = True - perturb_bool['ia'] = True - perturb_bool['ii'] = True - elif mode == 'non-intrusive': - perturb_bool['aa_a'] = False - perturb_bool['aa_i'] = False - perturb_bool['ai'] = True - perturb_bool['ia'] = False - perturb_bool['ii'] = True - elif mode == 'minimal': - perturb_bool['aa_a'] = True - perturb_bool['aa_i'] = False - perturb_bool['ai'] = False - perturb_bool['ia'] = False - perturb_bool['ii'] = False - else: - raise Exception('mode not recognized, valid modes are \'all\', \'non-intrusive\', \'minimal\'.') - - for l in range(self.depth): - funs_name = self.symbolic_fun[l].funs_name - for j in range(self.width_out[l+1]): - for i in range(self.width_in[l]): - out_array = list(np.array(self.symbolic_fun[l].funs_name)[j]) - in_array = list(np.array(self.symbolic_fun[l].funs_name)[:,i]) - out_active = len([i for i, x in enumerate(out_array) if x != "0"]) > 0 - in_active = len([i for i, x in enumerate(in_array) if x != "0"]) > 0 - dic = {True: 'a', False: 'i'} - edge_type = dic[in_active] + dic[out_active] - - if l < self.depth - 1 or mode != 'non-intrusive': - - if edge_type == 'aa': - if self.symbolic_fun[l].funs_name[j][i] == '0': - edge_type += '_i' - else: - edge_type += '_a' - - if perturb_bool[edge_type]: - self.act_fun[l].mask.data[i][j] = mag - - if l == self.depth - 1 and mode == 'non-intrusive': - - self.act_fun[l].mask.data[i][j] = torch.tensor(1.) - self.act_fun[l].scale_base.data[i][j] = torch.tensor(0.) - self.act_fun[l].scale_sp.data[i][j] = torch.tensor(0.) - - self.get_act(self.cache_data) - - self.log_history('perturb') - - - def module(self, start_layer, chain): - ''' - specify network modules - - Args: - ----- - start_layer : int - the earliest layer of the module - chain : str - specify neurons in the module - - Returns: - -------- - None - ''' - #chain = '[-1]->[-1,-2]->[-1]->[-1]' - groups = chain.split('->') - n_total_layers = len(groups)//2 - #start_layer = 0 - - for l in range(n_total_layers): - current_layer = cl = start_layer + l - id_in = [int(i) for i in groups[2*l][1:-1].split(',')] - id_out = [int(i) for i in groups[2*l+1][1:-1].split(',')] - - in_dim = self.width_in[cl] - out_dim = self.width_out[cl+1] - id_in_other = list(set(range(in_dim)) - set(id_in)) - id_out_other = list(set(range(out_dim)) - set(id_out)) - self.act_fun[cl].mask.data[np.ix_(id_in_other,id_out)] = 0. - self.act_fun[cl].mask.data[np.ix_(id_in,id_out_other)] = 0. - self.symbolic_fun[cl].mask.data[np.ix_(id_out,id_in_other)] = 0. - self.symbolic_fun[cl].mask.data[np.ix_(id_out_other,id_in)] = 0.
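- # the four mask assignments above cut, in both directions, every spline and - # symbolic edge between neurons inside the module and neurons outside it - # illustrative call (hypothetical indices), in the docstring's chain format: - # model.module(0, '[0,1]->[0]->[0]->[0]') - # i.e. starting at layer 0, inputs 0 and 1 connect only to neuron 0 of the - # next layer, which in turn connects only to neuron 0 of the layer after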
- - self.log_history('module') - - def tree(self, x=None, in_var=None, style='tree', sym_th=1e-3, sep_th=1e-1, skip_sep_test=False, verbose=False): - ''' - turn KAN into a tree - ''' - if x == None: - x = self.cache_data - plot_tree(self, x, in_var=in_var, style=style, sym_th=sym_th, sep_th=sep_th, skip_sep_test=skip_sep_test, verbose=verbose) - - - def speed(self, compile=False): - ''' - turn on KAN's speed mode - ''' - self.symbolic_enabled=False - self.save_act=False - self.auto_save=False - if compile == True: - return torch.compile(self) - else: - return self - - def get_act(self, x=None): - ''' - collect intermediate activations - ''' - if isinstance(x, dict): - x = x['train_input'] - if x == None: - if self.cache_data != None: - x = self.cache_data - else: - raise Exception("missing input data x") - save_act = self.save_act - self.save_act = True - self.forward(x) - self.save_act = save_act - - def get_fun(self, l, i, j): - ''' - get function (l,i,j) - ''' - inputs = self.spline_preacts[l][:,j,i].cpu().detach().numpy() - outputs = self.spline_postacts[l][:,j,i].cpu().detach().numpy() - # they are not ordered yet - rank = np.argsort(inputs) - inputs = inputs[rank] - outputs = outputs[rank] - plt.figure(figsize=(3,3)) - plt.plot(inputs, outputs, marker="o") - return inputs, outputs - - - def history(self, k='all'): - ''' - get history - ''' - with open(self.ckpt_path+'/history.txt', 'r') as f: - data = f.readlines() - n_line = len(data) - if k == 'all': - k = n_line - - data = data[-k:] - for line in data: - print(line[:-1]) - @property - def n_edge(self): - ''' - the number of active edges - ''' - depth = len(self.act_fun) - complexity = 0 - for l in range(depth): - complexity += torch.sum(self.act_fun[l].mask > 0.) - return complexity.item() - - def evaluate(self, dataset): - evaluation = {} - evaluation['test_loss'] = torch.sqrt(torch.mean((self.forward(dataset['test_input']) - dataset['test_label'])**2)).item() - evaluation['n_edge'] = self.n_edge - evaluation['n_grid'] = self.grid - # add other metrics (maybe accuracy) - return evaluation - - def swap(self, l, i1, i2, log_history=True): - - self.act_fun[l-1].swap(i1,i2,mode='out') - self.symbolic_fun[l-1].swap(i1,i2,mode='out') - self.act_fun[l].swap(i1,i2,mode='in') - self.symbolic_fun[l].swap(i1,i2,mode='in') - - def swap_(data, i1, i2): - data[i1], data[i2] = data[i2], data[i1] - - swap_(self.node_scale[l-1].data, i1, i2) - swap_(self.node_bias[l-1].data, i1, i2) - swap_(self.subnode_scale[l-1].data, i1, i2) - swap_(self.subnode_bias[l-1].data, i1, i2) - - if log_history: - self.log_history('swap') - - @property - def connection_cost(self): - - cc = 0.
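- # connection cost: place the n neurons of each layer at uniform coordinates in - # [0,1] and sum |x_in - x_out| over all edges, weighted by edge_scores; smaller - # values mean active edges connect spatially nearby neurons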
- for t in self.edge_scores: - - def get_coordinate(n): - return torch.linspace(0,1,steps=n+1, device=self.device)[:n] + 1/(2*n) - - in_dim = t.shape[0] - x_in = get_coordinate(in_dim) - - out_dim = t.shape[1] - x_out = get_coordinate(out_dim) - - dist = torch.abs(x_in[:,None] - x_out[None,:]) - cc += torch.sum(dist * t) - - return cc - - def auto_swap_l(self, l): - - num = self.width_in[l] - for i in range(num): - ccs = [] - for j in range(num): - self.swap(l,i,j,log_history=False) - self.get_act() - self.attribute() - cc = self.connection_cost.detach().clone() - ccs.append(cc) - self.swap(l,i,j,log_history=False) - j = torch.argmin(torch.tensor(ccs)) - self.swap(l,i,j,log_history=False) - - def auto_swap(self): - ''' - automatically swap neurons such that the connection cost is minimized - ''' - depth = self.depth - for l in range(1, depth): - self.auto_swap_l(l) - - self.log_history('auto_swap') - -KAN = MultKAN diff --git a/models/others/kolmogorov_arnold_networks/kan/pytorch/kan/.ipynb_checkpoints/Symbolic_KANLayer-checkpoint.py b/models/others/kolmogorov_arnold_networks/kan/pytorch/kan/.ipynb_checkpoints/Symbolic_KANLayer-checkpoint.py deleted file mode 100644 index 51baf0af5bd4acda730c6612c37ff87cf48c9f80..0000000000000000000000000000000000000000 --- a/models/others/kolmogorov_arnold_networks/kan/pytorch/kan/.ipynb_checkpoints/Symbolic_KANLayer-checkpoint.py +++ /dev/null @@ -1,270 +0,0 @@ -import torch -import torch.nn as nn -import numpy as np -import sympy -from .utils import * - - - -class Symbolic_KANLayer(nn.Module): - ''' - Symbolic_KANLayer class - - Attributes: - ----------- - in_dim : int - input dimension - out_dim : int - output dimension - funs : 2D array of torch functions (or lambda functions) - symbolic functions (torch) - funs_avoid_singularity : 2D array of torch functions (or lambda functions) with singularity avoiding - funs_name : 2D array of str - names of symbolic functions - funs_sympy : 2D array of sympy functions (or lambda functions) - symbolic functions (sympy) - affine : 3D array of floats - affine transformations of inputs and outputs - ''' - def __init__(self, in_dim=3, out_dim=2, device='cpu'): - ''' - initialize a Symbolic_KANLayer (activation functions are initialized to be zero functions) - - Args: - ----- - in_dim : int - input dimension - out_dim : int - output dimension - device : str - device - - Returns: - -------- - self - - Example - ------- - >>> sb = Symbolic_KANLayer(in_dim=3, out_dim=3) - >>> len(sb.funs), len(sb.funs[0]) - ''' - super(Symbolic_KANLayer, self).__init__() - self.out_dim = out_dim - self.in_dim = in_dim - self.mask = torch.nn.Parameter(torch.zeros(out_dim, in_dim, device=device)).requires_grad_(False) - # torch - self.funs = [[lambda x: x*0. for i in range(self.in_dim)] for j in range(self.out_dim)] - self.funs_avoid_singularity = [[lambda x, y_th: ((), x*0.) for i in range(self.in_dim)] for j in range(self.out_dim)] - # name - self.funs_name = [['0' for i in range(self.in_dim)] for j in range(self.out_dim)] - # sympy - self.funs_sympy = [[lambda x: x*0. for i in range(self.in_dim)] for j in range(self.out_dim)] - ### make funs_name the only parameter, and make others as the properties of funs_name?
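- # a fresh Symbolic_KANLayer is therefore inert: every edge is the zero function - # '0' with mask 0, and stays that way until fix_symbolic installs a function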
- - self.affine = torch.nn.Parameter(torch.zeros(out_dim, in_dim, 4, device=device)) - # c*f(a*x+b)+d - - self.device = device - self.to(device) - - def to(self, device): - ''' - move to device - ''' - super(Symbolic_KANLayer, self).to(device) - self.device = device - return self - - def forward(self, x, singularity_avoiding=False, y_th=10.): - ''' - forward - - Args: - ----- - x : 2D array - inputs, shape (batch, input dimension) - singularity_avoiding : bool - if True, funs_avoid_singularity is used; if False, funs is used. - y_th : float - the singularity threshold - - Returns: - -------- - y : 2D array - outputs, shape (batch, output dimension) - postacts : 3D array - activations after activation functions but before being summed on nodes - - Example - ------- - >>> sb = Symbolic_KANLayer(in_dim=3, out_dim=5) - >>> x = torch.normal(0,1,size=(100,3)) - >>> y, postacts = sb(x) - >>> y.shape, postacts.shape - (torch.Size([100, 5]), torch.Size([100, 5, 3])) - ''' - - batch = x.shape[0] - postacts = [] - - for i in range(self.in_dim): - postacts_ = [] - for j in range(self.out_dim): - if singularity_avoiding: - xij = self.affine[j,i,2]*self.funs_avoid_singularity[j][i](self.affine[j,i,0]*x[:,[i]]+self.affine[j,i,1], torch.tensor(y_th))[1]+self.affine[j,i,3] - else: - xij = self.affine[j,i,2]*self.funs[j][i](self.affine[j,i,0]*x[:,[i]]+self.affine[j,i,1])+self.affine[j,i,3] - postacts_.append(self.mask[j][i]*xij) - postacts.append(torch.stack(postacts_)) - - postacts = torch.stack(postacts) - postacts = postacts.permute(2,1,0,3)[:,:,:,0] - y = torch.sum(postacts, dim=2) - - return y, postacts - - - def get_subset(self, in_id, out_id): - ''' - get a smaller Symbolic_KANLayer from a larger Symbolic_KANLayer (used for pruning) - - Args: - ----- - in_id : list - id of selected input neurons - out_id : list - id of selected output neurons - - Returns: - -------- - sbb : Symbolic_KANLayer - - Example - ------- - >>> sb_large = Symbolic_KANLayer(in_dim=10, out_dim=10) - >>> sb_small = sb_large.get_subset([0,9],[1,2,3]) - >>> sb_small.in_dim, sb_small.out_dim - ''' - sbb = Symbolic_KANLayer(self.in_dim, self.out_dim, device=self.device) - sbb.in_dim = len(in_id) - sbb.out_dim = len(out_id) - sbb.mask.data = self.mask.data[out_id][:,in_id] - sbb.funs = [[self.funs[j][i] for i in in_id] for j in out_id] - sbb.funs_avoid_singularity = [[self.funs_avoid_singularity[j][i] for i in in_id] for j in out_id] - sbb.funs_sympy = [[self.funs_sympy[j][i] for i in in_id] for j in out_id] - sbb.funs_name = [[self.funs_name[j][i] for i in in_id] for j in out_id] - sbb.affine.data = self.affine.data[out_id][:,in_id] - return sbb - - - def fix_symbolic(self, i, j, fun_name, x=None, y=None, random=False, a_range=(-10,10), b_range=(-10,10), verbose=True): - ''' - fix an activation function to be symbolic - - Args: - ----- - i : int - the id of input neuron - j : int - the id of output neuron - fun_name : str - the name of the symbolic functions - x : 1D array - preactivations - y : 1D array - postactivations - random : bool - if True, affine parameters are initialized randomly - a_range : tuple - sweeping range of a - b_range : tuple - sweeping range of b - verbose : bool - print more information if True - - Returns: - -------- - r2 (coefficient of determination) - - Example 1 - --------- - >>> # when x & y are not provided.
Affine parameters are set to a = 1, b = 0, c = 1, d = 0 - >>> sb = Symbolic_KANLayer(in_dim=3, out_dim=2) - >>> sb.fix_symbolic(2,1,'sin') - >>> print(sb.funs_name) - >>> print(sb.affine) - - Example 2 - --------- - >>> # when x & y are provided, fit_params() is called to find the best fit coefficients - >>> sb = Symbolic_KANLayer(in_dim=3, out_dim=2) - >>> batch = 100 - >>> x = torch.linspace(-1,1,steps=batch) - >>> noises = torch.normal(0,1,(batch,)) * 0.02 - >>> y = 5.0*torch.sin(3.0*x + 2.0) + 0.7 + noises - >>> sb.fix_symbolic(2,1,'sin',x,y) - >>> print(sb.funs_name) - >>> print(sb.affine[1,2,:].data) - ''' - if isinstance(fun_name,str): - fun = SYMBOLIC_LIB[fun_name][0] - fun_sympy = SYMBOLIC_LIB[fun_name][1] - fun_avoid_singularity = SYMBOLIC_LIB[fun_name][3] - self.funs_sympy[j][i] = fun_sympy - self.funs_name[j][i] = fun_name - - if x == None or y == None: - #initialize from just fun - self.funs[j][i] = fun - self.funs_avoid_singularity[j][i] = fun_avoid_singularity - if random == False: - self.affine.data[j][i] = torch.tensor([1.,0.,1.,0.]) - else: - self.affine.data[j][i] = torch.rand(4,) * 2 - 1 - return None - else: - #initialize from x & y and fun - params, r2 = fit_params(x,y,fun, a_range=a_range, b_range=b_range, verbose=verbose, device=self.device) - self.funs[j][i] = fun - self.funs_avoid_singularity[j][i] = fun_avoid_singularity - self.affine.data[j][i] = params - return r2 - else: - # if fun_name itself is a function - fun = fun_name - fun_sympy = fun_name - self.funs_sympy[j][i] = fun_sympy - self.funs_name[j][i] = "anonymous" - - self.funs[j][i] = fun - self.funs_avoid_singularity[j][i] = fun - if random == False: - self.affine.data[j][i] = torch.tensor([1.,0.,1.,0.]) - else: - self.affine.data[j][i] = torch.rand(4,) * 2 - 1 - return None - - def swap(self, i1, i2, mode='in'): - ''' - swap the i1 neuron with the i2 neuron in input (if mode == 'in') or output (if mode == 'out') - ''' - with torch.no_grad(): - def swap_list_(data, i1, i2, mode='in'): - - if mode == 'in': - for j in range(self.out_dim): - data[j][i1], data[j][i2] = data[j][i2], data[j][i1] - - elif mode == 'out': - data[i1], data[i2] = data[i2], data[i1] - - def swap_(data, i1, i2, mode='in'): - if mode == 'in': - data[:,i1], data[:,i2] = data[:,i2].clone(), data[:,i1].clone() - - elif mode == 'out': - data[i1], data[i2] = data[i2].clone(), data[i1].clone() - - swap_list_(self.funs_name,i1,i2,mode) - swap_list_(self.funs_sympy,i1,i2,mode) - swap_list_(self.funs_avoid_singularity,i1,i2,mode) - swap_(self.affine.data,i1,i2,mode) - swap_(self.mask.data,i1,i2,mode) diff --git a/models/others/kolmogorov_arnold_networks/kan/pytorch/kan/.ipynb_checkpoints/__init__-checkpoint.py b/models/others/kolmogorov_arnold_networks/kan/pytorch/kan/.ipynb_checkpoints/__init__-checkpoint.py deleted file mode 100644 index 1ce0e47b21c80dcb7007e9742ec09b79b245dff9..0000000000000000000000000000000000000000 --- a/models/others/kolmogorov_arnold_networks/kan/pytorch/kan/.ipynb_checkpoints/__init__-checkpoint.py +++ /dev/null @@ -1,3 +0,0 @@ -from .MultKAN import * -from .utils import * -#torch.use_deterministic_algorithms(True) \ No newline at end of file diff --git a/models/others/kolmogorov_arnold_networks/kan/pytorch/kan/.ipynb_checkpoints/compiler-checkpoint.py b/models/others/kolmogorov_arnold_networks/kan/pytorch/kan/.ipynb_checkpoints/compiler-checkpoint.py deleted file mode 100644 index c8014829e83b1a8a67687643d5b17d02286c2d3e..0000000000000000000000000000000000000000 ---
a/models/others/kolmogorov_arnold_networks/kan/pytorch/kan/.ipynb_checkpoints/compiler-checkpoint.py +++ /dev/null @@ -1,498 +0,0 @@ -from sympy import * -import sympy -import numpy as np -from kan.MultKAN import MultKAN -import torch - -def next_nontrivial_operation(expr, scale=1, bias=0): - ''' - remove the affine part of an expression - - Args: - ----- - expr : sympy expression - scale : float - bias : float - - Returns: - -------- - expr : sympy expression - scale : float - bias : float - - Example - ------- - >>> from kan.compiler import * - >>> from sympy import * - >>> input_vars = a, b = symbols('a b') - >>> expression = 3.14534242 * exp(sin(pi*a) + b**2) - 2.32345402 - >>> next_nontrivial_operation(expression) - ''' - if expr.func == Add or expr.func == Mul: - n_arg = len(expr.args) - n_num = 0 - n_var_id = [] - n_num_id = [] - var_args = [] - for i in range(n_arg): - is_number = expr.args[i].is_number - n_num += is_number - if not is_number: - n_var_id.append(i) - var_args.append(expr.args[i]) - else: - n_num_id.append(i) - if n_num > 0: - # trivial - if expr.func == Add: - for i in range(n_num): - if i == 0: - bias = expr.args[n_num_id[i]] - else: - bias += expr.args[n_num_id[i]] - if expr.func == Mul: - for i in range(n_num): - if i == 0: - scale = expr.args[n_num_id[i]] - else: - scale *= expr.args[n_num_id[i]] - - return next_nontrivial_operation(expr.func(*var_args), scale, bias) - else: - return expr, scale, bias - else: - return expr, scale, bias - - -def expr2kan(input_variables, expr, grid=5, k=3, auto_save=False): - ''' - compile a symbolic formula to a MultKAN - - Args: - ----- - input_variables : a list of sympy symbols - expr : sympy expression - grid : int - the number of grid intervals - k : int - spline order - auto_save : bool - if auto_save = True, models are automatically saved - - Returns: - -------- - MultKAN - - Example - ------- - >>> from kan.compiler import * - >>> from sympy import * - >>> input_vars = a, b = symbols('a b') - >>> expression = exp(sin(pi*a) + b**2) - >>> model = kanpiler(input_vars, expression) - >>> x = torch.rand(100,2) * 2 - 1 - >>> model(x) - >>> model.plot() - ''' - class Node: - def __init__(self, expr, mult_bool, depth, scale, bias, parent=None, mult_arity=None): - self.expr = expr - self.mult_bool = mult_bool - if self.mult_bool: - self.mult_arity = mult_arity - self.depth = depth - - if len(Nodes) <= depth: - Nodes.append([]) - index = 0 - else: - index = len(Nodes[depth]) - - Nodes[depth].append(self) - - self.index = index - if parent == None: - self.parent_index = None - else: - self.parent_index = parent.index - self.child_index = [] - - # update parent's child_index - if parent != None: - parent.child_index.append(self.index) - - - self.scale = scale - self.bias = bias - - - class SubNode: - def __init__(self, expr, depth, scale, bias, parent=None): - self.expr = expr - self.depth = depth - - if len(SubNodes) <= depth: - SubNodes.append([]) - index = 0 - else: - index = len(SubNodes[depth]) - - SubNodes[depth].append(self) - - self.index = index - self.parent_index = None # shape: (2,) - self.child_index = [] # shape: (n, 2) - - # update parent's child_index - parent.child_index.append(self.index) - - self.scale = scale - self.bias = bias - - - class Connection: - def __init__(self, affine, fun, fun_name, parent=None, child=None, power_exponent=None): - # connection = activation function that connects a subnode to a node in the next layer node - self.affine = affine #[1,0,1,0] # (a,b,c,d) - self.fun = fun # y = 
c*fun(a*x+b)+d - self.fun_name = fun_name - self.parent_index = parent.index - self.depth = parent.depth - self.child_index = child.index - self.power_exponent = power_exponent # if fun == Pow - Connections[(self.depth,self.parent_index,self.child_index)] = self - - def create_node(expr, parent=None, n_layer=None): - #print('before', expr) - expr, scale, bias = next_nontrivial_operation(expr) - #print('after', expr) - if parent == None: - depth = 0 - else: - depth = parent.depth - - - if expr.func == Mul: - mult_arity = len(expr.args) - node = Node(expr, True, depth, scale, bias, parent=parent, mult_arity=mult_arity) - # create mult_arity SubNodes, + 1 - for i in range(mult_arity): - # create SubNode - expr_i, scale, bias = next_nontrivial_operation(expr.args[i]) - subnode = SubNode(expr_i, node.depth+1, scale, bias, parent=node) - if expr_i.func == Add: - for j in range(len(expr_i.args)): - expr_ij, scale, bias = next_nontrivial_operation(expr_i.args[j]) - # expr_ij is impossible to be Add, should be Mul or 1D - if expr_ij.func == Mul: - #print(expr_ij) - # create a node with expr_ij - new_node = create_node(expr_ij, parent=subnode, n_layer=n_layer) - # create a connection which is a linear function - c = Connection([1,0,float(scale),float(bias)], lambda x: x, 'x', parent=subnode, child=new_node) - - elif expr_ij.func == Symbol: - #print(expr_ij) - new_node = create_node(expr_ij, parent=subnode, n_layer=n_layer) - c = Connection([1,0,float(scale),float(bias)], lambda x: x, fun_name = 'x', parent=subnode, child=new_node) - - else: - # 1D function case - # create a node with expr_ij.args[0] - new_node = create_node(expr_ij.args[0], parent=subnode, n_layer=n_layer) - # create 1D function expr_ij.func - if expr_ij.func == Pow: - power_exponent = expr_ij.args[1] - else: - power_exponent = None - Connection([1,0,float(scale),float(bias)], expr_ij.func, fun_name = expr_ij.func, parent=subnode, child=new_node, power_exponent=power_exponent) - - - elif expr_i.func == Mul: - # create a node with expr_i - new_node = create_node(expr_i, parent=subnode, n_layer=n_layer) - # create 1D function, linear - Connection([1,0,1,0], lambda x: x, fun_name = 'x', parent=subnode, child=new_node) - - elif expr_i.func == Symbol: - new_node = create_node(expr_i, parent=subnode, n_layer=n_layer) - Connection([1,0,1,0], lambda x: x, fun_name = 'x', parent=subnode, child=new_node) - - else: - # 1D functions - # create a node with expr_i.args[0] - new_node = create_node(expr_i.args[0], parent=subnode, n_layer=n_layer) - # create 1D function expr_i.func - if expr_i.func == Pow: - power_exponent = expr_i.args[1] - else: - power_exponent = None - Connection([1,0,1,0], expr_i.func, fun_name = expr_i.func, parent=subnode, child=new_node, power_exponent=power_exponent) - - elif expr.func == Add: - - node = Node(expr, False, depth, scale, bias, parent=parent) - subnode = SubNode(expr, node.depth+1, 1, 0, parent=node) - - for i in range(len(expr.args)): - expr_i, scale, bias = next_nontrivial_operation(expr.args[i]) - if expr_i.func == Mul: - # create a node with expr_i - new_node = create_node(expr_i, parent=subnode, n_layer=n_layer) - # create a connection which is a linear function - Connection([1,0,float(scale),float(bias)], lambda x: x, fun_name = 'x', parent=subnode, child=new_node) - - elif expr_i.func == Symbol: - new_node = create_node(expr_i, parent=subnode, n_layer=n_layer) - Connection([1,0,float(scale),float(bias)], lambda x: x, fun_name = 'x', parent=subnode, child=new_node) - - else: - # 1D function case - # 
create a node with expr_ij.args[0] - new_node = create_node(expr_i.args[0], parent=subnode, n_layer=n_layer) - # create 1D function expr_i.func - if expr_i.func == Pow: - power_exponent = expr_i.args[1] - else: - power_exponent = None - Connection([1,0,float(scale),float(bias)], expr_i.func, fun_name = expr_i.func, parent=subnode, child=new_node, power_exponent=power_exponent) - - elif expr.func == Symbol: - # expr.func is a symbol (one of input variables) - if n_layer == None: - node = Node(expr, False, depth, scale, bias, parent=parent) - else: - node = Node(expr, False, depth, scale, bias, parent=parent) - return_node = node - for i in range(n_layer - depth): - subnode = SubNode(expr, node.depth+1, 1, 0, parent=node) - node = Node(expr, False, subnode.depth, 1, 0, parent=subnode) - Connection([1,0,1,0], lambda x: x, fun_name = 'x', parent=subnode, child=node) - node = return_node - - Start_Nodes.append(node) - - else: - # expr.func is 1D function - #print(expr, scale, bias) - node = Node(expr, False, depth, scale, bias, parent=parent) - expr_i, scale, bias = next_nontrivial_operation(expr.args[0]) - subnode = SubNode(expr_i, node.depth+1, 1, 0, parent=node) - # create a node with expr_i.args[0] - new_node = create_node(expr.args[0], parent=subnode, n_layer=n_layer) - # create 1D function expr_i.func - if expr.func == Pow: - power_exponent = expr.args[1] - else: - power_exponent = None - Connection([1,0,1,0], expr.func, fun_name = expr.func, parent=subnode, child=new_node, power_exponent=power_exponent) - - return node - - Nodes = [[]] - SubNodes = [[]] - Connections = {} - Start_Nodes = [] - - create_node(expr, n_layer=None) - - n_layer = len(Nodes) - 1 - - Nodes = [[]] - SubNodes = [[]] - Connections = {} - Start_Nodes = [] - - create_node(expr, n_layer=n_layer) - - # move affine parameters in leaf nodes to connections - for node in Start_Nodes: - c = Connections[(node.depth,node.parent_index,node.index)] - c.affine[0] = float(node.scale) - c.affine[1] = float(node.bias) - node.scale = 1. - node.bias = 0. 
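- # two-pass construction: the first create_node call (n_layer=None) only measures - # the tree depth; the second pass uses it to pad every symbol leaf with identity - # connections so that all input variables sit at the same (deepest) layer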
- - #input_variables = symbol - node2var = [] - for node in Start_Nodes: - for i in range(len(input_variables)): - if node.expr == input_variables[i]: - node2var.append(i) - - # Nodes - n_mult = [] - n_sum = [] - for layer in Nodes: - n_mult.append(0) - n_sum.append(0) - for node in layer: - if node.mult_bool == True: - n_mult[-1] += 1 - else: - n_sum[-1] += 1 - - # depth - n_layer = len(Nodes) - 1 - - # converter - # input tree node id, output kan node id (distinguish sum and mult node) - # input tree subnode id, output kan subnode id - # node id - subnode_index_convert = {} - node_index_convert = {} - connection_index_convert = {} - mult_arities = [] - for layer_id in range(n_layer+1): - mult_arity = [] - i_sum = 0 - i_mult = 0 - for i in range(len(Nodes[layer_id])): - node = Nodes[layer_id][i] - if node.mult_bool == True: - kan_node_id = n_sum[layer_id] + i_mult - arity = len(node.child_index) - for i in range(arity): - subnode = SubNodes[node.depth+1][node.child_index[i]] - kan_subnode_id = n_sum[layer_id] + np.sum(mult_arity) + i - subnode_index_convert[(subnode.depth,subnode.index)] = (int(n_layer-subnode.depth),int(kan_subnode_id)) - i_mult += 1 - mult_arity.append(arity) - else: - kan_node_id = i_sum - if len(node.child_index) > 0: - subnode = SubNodes[node.depth+1][node.child_index[0]] - kan_subnode_id = i_sum - subnode_index_convert[(subnode.depth,subnode.index)] = (int(n_layer-subnode.depth),int(kan_subnode_id)) - i_sum += 1 - - if layer_id == n_layer: - # input layer - node_index_convert[(node.depth,node.index)] = (int(n_layer-node.depth),int(node2var[kan_node_id])) - else: - node_index_convert[(node.depth,node.index)] = (int(n_layer-node.depth),int(kan_node_id)) - - # node: depth (node.depth -> n_layer - node.depth) - # width (node.index -> kan_node_id) - # subnode: depth (subnode.depth -> n_layer - subnode.depth) - # width (subnode.index -> kan_subnode_id) - mult_arities.append(mult_arity) - - for index in list(Connections.keys()): - depth, subnode_id, node_id = index - # to int(n_layer-depth), - _, kan_subnode_id = subnode_index_convert[(depth, subnode_id)] - _, kan_node_id = node_index_convert[(depth, node_id)] - connection_index_convert[(depth, subnode_id, node_id)] = (n_layer-depth, kan_subnode_id, kan_node_id) - - - n_sum.reverse() - n_mult.reverse() - mult_arities.reverse() - - width = [[n_sum[i], n_mult[i]] for i in range(len(n_sum))] - width[0][0] = len(input_variables) - - # allow passing in other parameters (probably as a dictionary) in sf2kan, including grid, k, etc.
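- # index flip: tree depth is counted from the root while KAN layers are counted - # from the input, hence the (n_layer - depth) conversion in the dictionaries above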
- model = MultKAN(width=width, mult_arity=mult_arities, grid=grid, k=k, auto_save=auto_save) - - # clean the graph - for l in range(model.depth): - for i in range(model.width_in[l]): - for j in range(model.width_out[l+1]): - model.fix_symbolic(l,i,j,'0',fit_params_bool=False) - - # Nodes - Nodes_flat = [x for xs in Nodes for x in xs] - - self = model - - for node in Nodes_flat: - node_depth = node.depth - node_index = node.index - kan_node_depth, kan_node_index = node_index_convert[(node_depth,node_index)] - #print(kan_node_depth, kan_node_index) - if kan_node_depth > 0: - self.node_scale[kan_node_depth-1].data[kan_node_index] = float(node.scale) - self.node_bias[kan_node_depth-1].data[kan_node_index] = float(node.bias) - - - # SubNodes - SubNodes_flat = [x for xs in SubNodes for x in xs] - - for subnode in SubNodes_flat: - subnode_depth = subnode.depth - subnode_index = subnode.index - kan_subnode_depth, kan_subnode_index = subnode_index_convert[(subnode_depth,subnode_index)] - #print(kan_subnode_depth, kan_subnode_index) - self.subnode_scale[kan_subnode_depth].data[kan_subnode_index] = float(subnode.scale) - self.subnode_bias[kan_subnode_depth].data[kan_subnode_index] = float(subnode.bias) - - # Connections - Connections_flat = list(Connections.values()) - - for connection in Connections_flat: - c_depth = connection.depth - c_j = connection.parent_index - c_i = connection.child_index - kc_depth, kc_j, kc_i = connection_index_convert[(c_depth, c_j, c_i)] - - # get symbolic fun_name - fun_name = connection.fun_name - #if fun_name == Pow: - # print(connection.power_exponent) - - if fun_name == 'x': - kfun_name = 'x' - elif fun_name == exp: - kfun_name = 'exp' - elif fun_name == sin: - kfun_name = 'sin' - elif fun_name == cos: - kfun_name = 'cos' - elif fun_name == tan: - kfun_name = 'tan' - elif fun_name == sqrt: - kfun_name = 'sqrt' - elif fun_name == log: - kfun_name = 'log' - elif fun_name == tanh: - kfun_name = 'tanh' - elif fun_name == asin: - kfun_name = 'arcsin' - elif fun_name == acos: - kfun_name = 'arccos' - elif fun_name == atan: - kfun_name = 'arctan' - elif fun_name == atanh: - kfun_name = 'arctanh' - elif fun_name == sign: - kfun_name = 'sgn' - elif fun_name == Pow: - alpha = connection.power_exponent - if alpha == Rational(1,2): - kfun_name = 'x^0.5' - elif alpha == - Rational(1,2): - kfun_name = '1/x^0.5' - elif alpha == Rational(3,2): - kfun_name = 'x^1.5' - else: - alpha = int(connection.power_exponent) - if alpha > 0: - if alpha == 1: - kfun_name = 'x' - else: - kfun_name = f'x^{alpha}' - else: - if alpha == -1: - kfun_name = '1/x' - else: - kfun_name = f'1/x^{-alpha}' - - model.fix_symbolic(kc_depth, kc_i, kc_j, kfun_name, fit_params_bool=False) - model.symbolic_fun[kc_depth].affine.data.reshape(self.width_out[kc_depth+1], self.width_in[kc_depth], 4)[kc_j][kc_i] = torch.tensor(connection.affine) - - return model - - -sf2kan = kanpiler = expr2kan \ No newline at end of file diff --git a/models/others/kolmogorov_arnold_networks/kan/pytorch/kan/.ipynb_checkpoints/experiment-checkpoint.py b/models/others/kolmogorov_arnold_networks/kan/pytorch/kan/.ipynb_checkpoints/experiment-checkpoint.py deleted file mode 100644 index 9ab9e9de3197e10e2ce2c57b6db361cbf18628bb..0000000000000000000000000000000000000000 --- a/models/others/kolmogorov_arnold_networks/kan/pytorch/kan/.ipynb_checkpoints/experiment-checkpoint.py +++ /dev/null @@ -1,55 +0,0 @@ -import torch -from .MultKAN import * - - -def runner1(width, dataset, grids=[5,10,20], steps=20, lamb=0.001, prune_round=3, refine_round=3, 
edge_th=1e-2, node_th=1e-2, metrics=None, seed=1): - - result = {} - result['test_loss'] = [] - result['c'] = [] - result['G'] = [] - result['id'] = [] - if metrics != None: - for i in range(len(metrics)): - result[metrics[i].__name__] = [] - - def collect(evaluation): - result['test_loss'].append(evaluation['test_loss']) - result['c'].append(evaluation['n_edge']) - result['G'].append(evaluation['n_grid']) - result['id'].append(f'{model.round}.{model.state_id}') - if metrics != None: - for i in range(len(metrics)): - result[metrics[i].__name__].append(metrics[i](model, dataset).item()) - - for i in range(prune_round): - # train and prune - if i == 0: - model = KAN(width=width, grid=grids[0], seed=seed) - else: - model = model.rewind(f'{i-1}.{2*i}') - - model.fit(dataset, steps=steps, lamb=lamb) - model = model.prune(edge_th=edge_th, node_th=node_th) - evaluation = model.evaluate(dataset) - collect(evaluation) - - for j in range(refine_round): - model = model.refine(grids[j]) - model.fit(dataset, steps=steps) - evaluation = model.evaluate(dataset) - collect(evaluation) - - for key in list(result.keys()): - result[key] = np.array(result[key]) - - return result - - -def pareto_frontier(x,y): - - pf_id = np.where(np.sum((x[:,None] <= x[None,:]) * (y[:,None] <= y[None,:]), axis=0) == 1)[0] - x_pf = x[pf_id] - y_pf = y[pf_id] - - return x_pf, y_pf, pf_id \ No newline at end of file diff --git a/models/others/kolmogorov_arnold_networks/kan/pytorch/kan/.ipynb_checkpoints/feynman-checkpoint.py b/models/others/kolmogorov_arnold_networks/kan/pytorch/kan/.ipynb_checkpoints/feynman-checkpoint.py deleted file mode 100644 index 6cc55e96ff947391434abf9168b5d6792a3d7119..0000000000000000000000000000000000000000 --- a/models/others/kolmogorov_arnold_networks/kan/pytorch/kan/.ipynb_checkpoints/feynman-checkpoint.py +++ /dev/null @@ -1,739 +0,0 @@ -from sympy import * -import torch - - -def get_feynman_dataset(name): - - global symbols - - tpi = torch.tensor(torch.pi) - - if name == 'test': - symbol = x, y = symbols('x, y') - expr = (x+y) * sin(exp(2*y)) - f = lambda x: (x[:,[0]] + x[:,[1]])*torch.sin(torch.exp(2*x[:,[1]])) - ranges = [-1,1] - - if name == 'I.6.20a' or name == 1: - symbol = theta = symbols('theta') - symbol = [symbol] - expr = exp(-theta**2/2)/sqrt(2*pi) - f = lambda x: torch.exp(-x[:,[0]]**2/2)/torch.sqrt(2*tpi) - ranges = [[-3,3]] - - if name == 'I.6.20' or name == 2: - symbol = theta, sigma = symbols('theta sigma') - expr = exp(-theta**2/(2*sigma**2))/sqrt(2*pi*sigma**2) - f = lambda x: torch.exp(-x[:,[0]]**2/(2*x[:,[1]]**2))/torch.sqrt(2*tpi*x[:,[1]]**2) - ranges = [[-1,1],[0.5,2]] - - if name == 'I.6.20b' or name == 3: - symbol = theta, theta1, sigma = symbols('theta theta1 sigma') - expr = exp(-(theta-theta1)**2/(2*sigma**2))/sqrt(2*pi*sigma**2) - f = lambda x: torch.exp(-(x[:,[0]]-x[:,[1]])**2/(2*x[:,[2]]**2))/torch.sqrt(2*tpi*x[:,[2]]**2) - ranges = [[-1.5,1.5],[-1.5,1.5],[0.5,2]] - - if name == 'I.8.4' or name == 4: - symbol = x1, x2, y1, y2 = symbols('x1 x2 y1 y2') - expr = sqrt((x2-x1)**2+(y2-y1)**2) - f = lambda x: torch.sqrt((x[:,[1]]-x[:,[0]])**2+(x[:,[3]]-x[:,[2]])**2) - ranges = [[-1,1],[-1,1],[-1,1],[-1,1]] - - if name == 'I.9.18' or name == 5: - symbol = G, m1, m2, x1, x2, y1, y2, z1, z2 = symbols('G m1 m2 x1 x2 y1 y2 z1 z2') - expr = G*m1*m2/((x2-x1)**2+(y2-y1)**2+(z2-z1)**2) - f = lambda x: x[:,[0]]*x[:,[1]]*x[:,[2]]/((x[:,[3]]-x[:,[4]])**2+(x[:,[5]]-x[:,[6]])**2+(x[:,[7]]-x[:,[8]])**2) - ranges = [[-1,1],[-1,1],[-1,1],[-1,-0.5],[0.5,1],[-1,-0.5],[0.5,1],[-1,-0.5],[0.5,1]] 
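- # each branch of get_feynman_dataset fills in (symbol, expr, f, ranges): the - # sympy symbols, the symbolic expression, a torch-vectorized lambda over x of - # shape (batch, n_vars), and per-variable sampling ranges (a single [lo, hi] - # is shared by all variables when not given per variable) - # illustrative usage (sketch): symbol, expr, f, ranges = get_feynman_dataset('I.6.20a')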
- - if name == 'I.10.7' or name == 6: - symbol = m0, v, c = symbols('m0 v c') - expr = m0/sqrt(1-v**2/c**2) - f = lambda x: x[:,[0]]/torch.sqrt(1-x[:,[1]]**2/x[:,[2]]**2) - ranges = [[0,1],[0,1],[1,2]] - - if name == 'I.11.19' or name == 7: - symbol = x1, y1, x2, y2, x3, y3 = symbols('x1 y1 x2 y2 x3 y3') - expr = x1*y1 + x2*y2 + x3*y3 - f = lambda x: x[:,[0]]*x[:,[1]] + x[:,[2]]*x[:,[3]] + x[:,[4]]*x[:,[5]] - ranges = [-1,1] - - if name == 'I.12.1' or name == 8: - symbol = mu, Nn = symbols('mu N_n') - expr = mu * Nn - f = lambda x: x[:,[0]]*x[:,[1]] - ranges = [-1,1] - - if name == 'I.12.2' or name == 9: - symbol = q1, q2, eps, r = symbols('q1 q2 epsilon r') - expr = q1*q2/(4*pi*eps*r**2) - f = lambda x: x[:,[0]]*x[:,[1]]/(4*tpi*x[:,[2]]*x[:,[3]]**2) - ranges = [[-1,1],[-1,1],[0.5,2],[0.5,2]] - - if name == 'I.12.4' or name == 10: - symbol = q1, eps, r = symbols('q1 epsilon r') - expr = q1/(4*pi*eps*r**2) - f = lambda x: x[:,[0]]/(4*tpi*x[:,[1]]*x[:,[2]]**2) - ranges = [[-1,1],[0.5,2],[0.5,2]] - - if name == 'I.12.5' or name == 11: - symbol = q2, Ef = symbols('q2, E_f') - expr = q2*Ef - f = lambda x: x[:,[0]]*x[:,[1]] - ranges = [-1,1] - - if name == 'I.12.11' or name == 12: - symbol = q, Ef, B, v, theta = symbols('q E_f B v theta') - expr = q*(Ef + B*v*sin(theta)) - f = lambda x: x[:,[0]]*(x[:,[1]]+x[:,[2]]*x[:,[3]]*torch.sin(x[:,[4]])) - ranges = [[-1,1],[-1,1],[-1,1],[-1,1],[0,2*tpi]] - - if name == 'I.13.4' or name == 13: - symbol = m, v, u, w = symbols('m u v w') - expr = 1/2*m*(v**2+u**2+w**2) - f = lambda x: 1/2*x[:,[0]]*(x[:,[1]]**2+x[:,[2]]**2+x[:,[3]]**2) - ranges = [[-1,1],[-1,1],[-1,1],[-1,1]] - - if name == 'I.13.12' or name == 14: - symbol = G, m1, m2, r1, r2 = symbols('G m1 m2 r1 r2') - expr = G*m1*m2*(1/r2-1/r1) - f = lambda x: x[:,[0]]*x[:,[1]]*x[:,[2]]*(1/x[:,[4]]-1/x[:,[3]]) - ranges = [[0,1],[0,1],[0,1],[0.5,2],[0.5,2]] - - if name == 'I.14.3' or name == 15: - symbol = m, g, z = symbols('m g z') - expr = m*g*z - f = lambda x: x[:,[0]]*x[:,[1]]*x[:,[2]] - ranges = [[0,1],[0,1],[-1,1]] - - if name == 'I.14.4' or name == 16: - symbol = ks, x = symbols('k_s x') - expr = 1/2*ks*x**2 - f = lambda x: 1/2*x[:,[0]]*x[:,[1]]**2 - ranges = [[0,1],[-1,1]] - - if name == 'I.15.3x' or name == 17: - symbol = x, u, t, c = symbols('x u t c') - expr = (x-u*t)/sqrt(1-u**2/c**2) - f = lambda x: (x[:,[0]] - x[:,[1]]*x[:,[2]])/torch.sqrt(1-x[:,[1]]**2/x[:,[3]]**2) - ranges = [[-1,1],[-1,1],[-1,1],[1,2]] - - if name == 'I.15.3t' or name == 18: - symbol = t, u, x, c = symbols('t u x c') - expr = (t-u*x/c**2)/sqrt(1-u**2/c**2) - f = lambda x: (x[:,[0]] - x[:,[1]]*x[:,[2]]/x[:,[3]]**2)/torch.sqrt(1-x[:,[1]]**2/x[:,[3]]**2) - ranges = [[-1,1],[-1,1],[-1,1],[1,2]] - - if name == 'I.15.10' or name == 19: - symbol = m0, v, c = symbols('m0 v c') - expr = m0*v/sqrt(1-v**2/c**2) - f = lambda x: x[:,[0]]*x[:,[1]]/torch.sqrt(1-x[:,[1]]**2/x[:,[2]]**2) - ranges = [[-1,1],[-0.9,0.9],[1.1,2]] - - if name == 'I.16.6' or name == 20: - symbol = u, v, c = symbols('u v c') - expr = (u+v)/(1+u*v/c**2) - f = lambda x: x[:,[0]]*x[:,[1]]/(1+x[:,[0]]*x[:,[1]]/x[:,[2]]**2) - ranges = [[-0.8,0.8],[-0.8,0.8],[1,2]] - - if name == 'I.18.4' or name == 21: - symbol = m1, r1, m2, r2 = symbols('m1 r1 m2 r2') - expr = (m1*r1+m2*r2)/(m1+m2) - f = lambda x: (x[:,[0]]*x[:,[1]]+x[:,[2]]*x[:,[3]])/(x[:,[0]]+x[:,[2]]) - ranges = [[0.5,1],[-1,1],[0.5,1],[-1,1]] - - if name == 'I.18.4' or name == 22: - symbol = r, F, theta = symbols('r F theta') - expr = r*F*sin(theta) - f = lambda x: x[:,[0]]*x[:,[1]]*torch.sin(x[:,[2]]) - ranges = 
[[-1,1],[-1,1],[0,2*tpi]] - - if name == 'I.18.16' or name == 23: - symbol = m, r, v, theta = symbols('m r v theta') - expr = m*r*v*sin(theta) - f = lambda x: x[:,[0]]*x[:,[1]]*x[:,[2]]*torch.sin(x[:,[3]]) - ranges = [[-1,1],[-1,1],[-1,1],[0,2*tpi]] - - if name == 'I.24.6' or name == 24: - symbol = m, omega, omega0, x = symbols('m omega omega_0 x') - expr = 1/4*m*(omega**2+omega0**2)*x**2 - f = lambda x: 1/4*x[:,[0]]*(x[:,[1]]**2+x[:,[2]]**2)*x[:,[3]]**2 - ranges = [[0,1],[-1,1],[-1,1],[-1,1]] - - if name == 'I.25.13' or name == 25: - symbol = q, C = symbols('q C') - expr = q/C - f = lambda x: x[:,[0]]/x[:,[1]] - ranges = [[-1,1],[0.5,2]] - - if name == 'I.26.2' or name == 26: - symbol = n, theta2 = symbols('n theta2') - expr = asin(n*sin(theta2)) - f = lambda x: torch.arcsin(x[:,[0]]*torch.sin(x[:,[1]])) - ranges = [[0,0.99],[0,2*tpi]] - - if name == 'I.27.6' or name == 27: - symbol = d1, d2, n = symbols('d1 d2 n') - expr = 1/(1/d1+n/d2) - f = lambda x: 1/(1/x[:,[0]]+x[:,[2]]/x[:,[1]]) - ranges = [[0.5,2],[1,2],[0.5,2]] - - if name == 'I.29.4' or name == 28: - symbol = omega, c = symbols('omega c') - expr = omega/c - f = lambda x: x[:,[0]]/x[:,[1]] - ranges = [[0,1],[0.5,2]] - - if name == 'I.29.16' or name == 29: - symbol = x1, x2, theta1, theta2 = symbols('x1 x2 theta1 theta2') - expr = sqrt(x1**2+x2**2-2*x1*x2*cos(theta1-theta2)) - f = lambda x: torch.sqrt(x[:,[0]]**2+x[:,[1]]**2-2*x[:,[0]]*x[:,[1]]*torch.cos(x[:,[2]]-x[:,[3]])) - ranges = [[-1,1],[-1,1],[0,2*tpi],[0,2*tpi]] - - if name == 'I.30.3' or name == 30: - symbol = I0, n, theta = symbols('I_0 n theta') - expr = I0 * sin(n*theta/2)**2 / sin(theta/2) ** 2 - f = lambda x: x[:,[0]] * torch.sin(x[:,[1]]*x[:,[2]]/2)**2 / torch.sin(x[:,[2]]/2)**2 - ranges = [[0,1],[0,4],[0.4*tpi,1.6*tpi]] - - if name == 'I.30.5' or name == 31: - symbol = lamb, n, d = symbols('lambda n d') - expr = asin(lamb/(n*d)) - f = lambda x: torch.arcsin(x[:,[0]]/(x[:,[1]]*x[:,[2]])) - ranges = [[-1,1],[1,1.5],[1,1.5]] - - if name == 'I.32.5' or name == 32: - symbol = q, a, eps, c = symbols('q a epsilon c') - expr = q**2*a**2/(eps*c**3) - f = lambda x: x[:,[0]]**2*x[:,[1]]**2/(x[:,[2]]*x[:,[3]]**3) - ranges = [[-1,1],[-1,1],[0.5,2],[0.5,2]] - - if name == 'I.32.17' or name == 33: - symbol = eps, c, Ef, r, omega, omega0 = symbols('epsilon c E_f r omega omega_0') - expr = nsimplify((1/2*eps*c*Ef**2)*(8*pi*r**2/3)*(omega**4/(omega**2-omega0**2)**2)) - f = lambda x: (1/2*x[:,[0]]*x[:,[1]]*x[:,[2]]**2)*(8*tpi*x[:,[3]]**2/3)*(x[:,[4]]**4/(x[:,[4]]**2-x[:,[5]]**2)**2) - ranges = [[0,1],[0,1],[-1,1],[0,1],[0,1],[1,2]] - - if name == 'I.34.8' or name == 34: - symbol = q, V, B, p = symbols('q V B p') - expr = q*V*B/p - f = lambda x: x[:,[0]]*x[:,[1]]*x[:,[2]]/x[:,[3]] - ranges = [[-1,1],[-1,1],[-1,1],[0.5,2]] - - if name == 'I.34.10' or name == 35: - symbol = omega0, v, c = symbols('omega_0 v c') - expr = omega0/(1-v/c) - f = lambda x: x[:,[0]]/(1-x[:,[1]]/x[:,[2]]) - ranges = [[0,1],[0,0.9],[1.1,2]] - - if name == 'I.34.14' or name == 36: - symbol = omega0, v, c = symbols('omega_0 v c') - expr = omega0 * (1+v/c)/sqrt(1-v**2/c**2) - f = lambda x: x[:,[0]]*(1+x[:,[1]]/x[:,[2]])/torch.sqrt(1-x[:,[1]]**2/x[:,[2]]**2) - ranges = [[0,1],[-0.9,0.9],[1.1,2]] - - if name == 'I.34.27' or name == 37: - symbol = hbar, omega = symbols('hbar omega') - expr = hbar * omega - f = lambda x: x[:,[0]]*x[:,[1]] - ranges = [[-1,1],[-1,1]] - - if name == 'I.37.4' or name == 38: - symbol = I1, I2, delta = symbols('I_1 I_2 delta') - expr = I1 + I2 + 2*sqrt(I1*I2)*cos(delta) - f = lambda x: 
x[:,[0]] + x[:,[1]] + 2*torch.sqrt(x[:,[0]]*x[:,[1]])*torch.cos(x[:,[2]]) - ranges = [[0.1,1],[0.1,1],[0,2*tpi]] - - if name == 'I.38.12' or name == 39: - symbol = eps, hbar, m, q = symbols('epsilon hbar m q') - expr = 4*pi*eps*hbar**2/(m*q**2) - f = lambda x: 4*tpi*x[:,[0]]*x[:,[1]]**2/(x[:,[2]]*x[:,[3]]**2) - ranges = [[0,1],[0,1],[0.5,2],[0.5,2]] - - if name == 'I.39.10' or name == 40: - symbol = pF, V = symbols('p_F V') - expr = 3/2 * pF * V - f = lambda x: 3/2 * x[:,[0]] * x[:,[1]] - ranges = [[0,1],[0,1]] - - if name == 'I.39.11' or name == 41: - symbol = gamma, pF, V = symbols('gamma p_F V') - expr = pF * V/(gamma - 1) - f = lambda x: 1/(x[:,[0]]-1) * x[:,[1]] * x[:,[2]] - ranges = [[1.5,3],[0,1],[0,1]] - - if name == 'I.39.22' or name == 42: - symbol = n, kb, T, V = symbols('n k_b T V') - expr = n*kb*T/V - f = lambda x: x[:,[0]]*x[:,[1]]*x[:,[2]]/x[:,[3]] - ranges = [[0,1],[0,1],[0,1],[0.5,2]] - - if name == 'I.40.1' or name == 43: - symbol = n0, m, g, x, kb, T = symbols('n_0 m g x k_b T') - expr = n0 * exp(-m*g*x/(kb*T)) - f = lambda x: x[:,[0]] * torch.exp(-x[:,[1]]*x[:,[2]]*x[:,[3]]/(x[:,[4]]*x[:,[5]])) - ranges = [[0,1],[-1,1],[-1,1],[-1,1],[1,2],[1,2]] - - if name == 'I.41.16' or name == 44: - symbol = hbar, omega, c, kb, T = symbols('hbar omega c k_b T') - expr = hbar * omega**3/(pi**2*c**2*(exp(hbar*omega/(kb*T))-1)) - f = lambda x: x[:,[0]]*x[:,[1]]**3/(tpi**2*x[:,[2]]**2*(torch.exp(x[:,[0]]*x[:,[1]]/(x[:,[3]]*x[:,[4]]))-1)) - ranges = [[0.5,1],[0.5,1],[0.5,2],[0.5,2],[0.5,2]] - - if name == 'I.43.16' or name == 45: - symbol = mu, q, Ve, d = symbols('mu q V_e d') - expr = mu*q*Ve/d - f = lambda x: x[:,[0]]*x[:,[1]]*x[:,[2]]/x[:,[3]] - ranges = [[0,1],[0,1],[0,1],[0.5,2]] - - if name == 'I.43.31' or name == 46: - symbol = mu, kb, T = symbols('mu k_b T') - expr = mu*kb*T - f = lambda x: x[:,[0]]*x[:,[1]]*x[:,[2]] - ranges = [[0,1],[0,1],[0,1]] - - if name == 'I.43.43' or name == 47: - symbol = gamma, kb, v, A = symbols('gamma k_b v A') - expr = kb*v/A/(gamma-1) - f = lambda x: 1/(x[:,[0]]-1)*x[:,[1]]*x[:,[2]]/x[:,[3]] - ranges = [[1.5,3],[0,1],[0,1],[0.5,2]] - - if name == 'I.44.4' or name == 48: - symbol = n, kb, T, V1, V2 = symbols('n k_b T V_1 V_2') - expr = n*kb*T*log(V2/V1) - f = lambda x: x[:,[0]]*x[:,[1]]*x[:,[2]]*torch.log(x[:,[4]]/x[:,[3]]) - ranges = [[0,1],[0,1],[0,1],[0.5,2],[0.5,2]] - - if name == 'I.47.23' or name == 49: - symbol = gamma, p, rho = symbols('gamma p rho') - expr = sqrt(gamma*p/rho) - f = lambda x: torch.sqrt(x[:,[0]]*x[:,[1]]/x[:,[2]]) - ranges = [[0.1,1],[0.1,1],[0.5,2]] - - if name == 'I.48.20' or name == 50: - symbol = m, v, c = symbols('m v c') - expr = m*c**2/sqrt(1-v**2/c**2) - f = lambda x: x[:,[0]]*x[:,[2]]**2/torch.sqrt(1-x[:,[1]]**2/x[:,[2]]**2) - ranges = [[0,1],[-0.9,0.9],[1.1,2]] - - if name == 'I.50.26' or name == 51: - symbol = x1, alpha, omega, t = symbols('x_1 alpha omega t') - expr = x1*(cos(omega*t)+alpha*cos(omega*t)**2) - f = lambda x: x[:,[0]]*(torch.cos(x[:,[2]]*x[:,[3]])+x[:,[1]]*torch.cos(x[:,[2]]*x[:,[3]])**2) - ranges = [[0,1],[0,1],[0,2*tpi],[0,1]] - - if name == 'II.2.42' or name == 52: - symbol = kappa, T1, T2, A, d = symbols('kappa T_1 T_2 A d') - expr = kappa*(T2-T1)*A/d - f = lambda x: x[:,[0]]*(x[:,[2]]-x[:,[1]])*x[:,[3]]/x[:,[4]] - ranges = [[0,1],[0,1],[0,1],[0,1],[0.5,2]] - - if name == 'II.3.24' or name == 53: - symbol = P, r = symbols('P r') - expr = P/(4*pi*r**2) - f = lambda x: x[:,[0]]/(4*tpi*x[:,[1]]**2) - ranges = [[0,1],[0.5,2]] - - if name == 'II.4.23' or name == 54: - symbol = q, eps, r = symbols('q 
epsilon r') - expr = q/(4*pi*eps*r) - f = lambda x: x[:,[0]]/(4*tpi*x[:,[1]]*x[:,[2]]) - ranges = [[0,1],[0.5,2],[0.5,2]] - - if name == 'II.6.11' or name == 55: - symbol = eps, pd, theta, r = symbols('epsilon p_d theta r') - expr = 1/(4*pi*eps)*pd*cos(theta)/r**2 - f = lambda x: 1/(4*tpi*x[:,[0]])*x[:,[1]]*torch.cos(x[:,[2]])/x[:,[3]]**2 - ranges = [[0.5,2],[0,1],[0,2*tpi],[0.5,2]] - - if name == 'II.6.15a' or name == 56: - symbol = eps, pd, z, x, y, r = symbols('epsilon p_d z x y r') - expr = 3/(4*pi*eps)*pd*z/r**5*sqrt(x**2+y**2) - f = lambda x: 3/(4*tpi*x[:,[0]])*x[:,[1]]*x[:,[2]]/x[:,[5]]**5*torch.sqrt(x[:,[3]]**2+x[:,[4]]**2) - ranges = [[0.5,2],[0,1],[0,1],[0,1],[0,1],[0.5,2]] - - if name == 'II.6.15b' or name == 57: - symbol = eps, pd, r, theta = symbols('epsilon p_d r theta') - expr = 3/(4*pi*eps)*pd/r**3*cos(theta)*sin(theta) - f = lambda x: 3/(4*tpi*x[:,[0]])*x[:,[1]]/x[:,[2]]**3*torch.cos(x[:,[3]])*torch.sin(x[:,[3]]) - ranges = [[0.5,2],[0,1],[0.5,2],[0,2*tpi]] - - if name == 'II.8.7' or name == 58: - symbol = q, eps, d = symbols('q epsilon d') - expr = 3/5*q**2/(4*pi*eps*d) - f = lambda x: 3/5*x[:,[0]]**2/(4*tpi*x[:,[1]]*x[:,[2]]) - ranges = [[0,1],[0.5,2],[0.5,2]] - - if name == 'II.8.31' or name == 59: - symbol = eps, Ef = symbols('epsilon E_f') - expr = 1/2*eps*Ef**2 - f = lambda x: 1/2*x[:,[0]]*x[:,[1]]**2 - ranges = [[0,1],[0,1]] - - if name == 'I.10.9' or name == 60: - symbol = sigma, eps, chi = symbols('sigma epsilon chi') - expr = sigma/eps/(1+chi) - f = lambda x: x[:,[0]]/x[:,[1]]/(1+x[:,[2]]) - ranges = [[0,1],[0.5,2],[0,1]] - - if name == 'II.11.3' or name == 61: - symbol = q, Ef, m, omega0, omega = symbols('q E_f m omega_o omega') - expr = q*Ef/(m*(omega0**2-omega**2)) - f = lambda x: x[:,[0]]*x[:,[1]]/(x[:,[2]]*(x[:,[3]]**2-x[:,[4]]**2)) - ranges = [[0,1],[0,1],[0.5,2],[1.5,3],[0,1]] - - if name == 'II.11.17' or name == 62: - symbol = n0, pd, Ef, theta, kb, T = symbols('n_0 p_d E_f theta k_b T') - expr = n0*(1+pd*Ef*cos(theta)/(kb*T)) - f = lambda x: x[:,[0]]*(1+x[:,[1]]*x[:,[2]]*torch.cos(x[:,[3]])/(x[:,[4]]*x[:,[5]])) - ranges = [[0,1],[-1,1],[-1,1],[0,2*tpi],[0.5,2],[0.5,2]] - - - if name == 'II.11.20' or name == 63: - symbol = n, pd, Ef, kb, T = symbols('n p_d E_f k_b T') - expr = n*pd**2*Ef/(3*kb*T) - f = lambda x: x[:,[0]]*x[:,[1]]**2*x[:,[2]]/(3*x[:,[3]]*x[:,[4]]) - ranges = [[0,1],[0,1],[0,1],[0.5,2],[0.5,2]] - - if name == 'II.11.27' or name == 64: - symbol = n, alpha, eps, Ef = symbols('n alpha epsilon E_f') - expr = n*alpha/(1-n*alpha/3)*eps*Ef - f = lambda x: x[:,[0]]*x[:,[1]]/(1-x[:,[0]]*x[:,[1]]/3)*x[:,[2]]*x[:,[3]] - ranges = [[0,1],[0,2],[0,1],[0,1]] - - if name == 'II.11.28' or name == 65: - symbol = n, alpha = symbols('n alpha') - expr = 1 + n*alpha/(1-n*alpha/3) - f = lambda x: 1 + x[:,[0]]*x[:,[1]]/(1-x[:,[0]]*x[:,[1]]/3) - ranges = [[0,1],[0,2]] - - if name == 'II.13.17' or name == 66: - symbol = eps, c, l, r = symbols('epsilon c l r') - expr = 1/(4*pi*eps*c**2)*(2*l/r) - f = lambda x: 1/(4*tpi*x[:,[0]]*x[:,[1]]**2)*(2*x[:,[2]]/x[:,[3]]) - ranges = [[0.5,2],[0.5,2],[0,1],[0.5,2]] - - if name == 'II.13.23' or name == 67: - symbol = rho, v, c = symbols('rho v c') - expr = rho/sqrt(1-v**2/c**2) - f = lambda x: x[:,[0]]/torch.sqrt(1-x[:,[1]]**2/x[:,[2]]**2) - ranges = [[0,1],[0,1],[1,2]] - - if name == 'II.13.34' or name == 68: - symbol = rho, v, c = symbols('rho v c') - expr = rho*v/sqrt(1-v**2/c**2) - f = lambda x: x[:,[0]]*x[:,[1]]/torch.sqrt(1-x[:,[1]]**2/x[:,[2]]**2) - ranges = [[0,1],[0,1],[1,2]] - - if name == 'II.15.4' or name == 69: - 
symbol = muM, B, theta = symbols('mu_M B theta') - expr = - muM * B * cos(theta) - f = lambda x: - x[:,[0]]*x[:,[1]]*torch.cos(x[:,[2]]) - ranges = [[0,1],[0,1],[0,2*tpi]] - - if name == 'II.15.5' or name == 70: - symbol = pd, Ef, theta = symbols('p_d E_f theta') - expr = - pd * Ef * cos(theta) - f = lambda x: - x[:,[0]]*x[:,[1]]*torch.cos(x[:,[2]]) - ranges = [[0,1],[0,1],[0,2*tpi]] - - if name == 'II.21.32' or name == 71: - symbol = q, eps, r, v, c = symbols('q epsilon r v c') - expr = q/(4*pi*eps*r*(1-v/c)) - f = lambda x: x[:,[0]]/(4*tpi*x[:,[1]]*x[:,[2]]*(1-x[:,[3]]/x[:,[4]])) - ranges = [[0,1],[0.5,2],[0.5,2],[0,1],[1,2]] - - if name == 'II.24.17' or name == 72: - symbol = omega, c, d = symbols('omega c d') - expr = sqrt(omega**2/c**2-pi**2/d**2) - f = lambda x: torch.sqrt(x[:,[0]]**2/x[:,[1]]**2-tpi**2/x[:,[2]]**2) - ranges = [[1,1.5],[0.75,1],[1*tpi,1.5*tpi]] - - if name == 'II.27.16' or name == 73: - symbol = eps, c, Ef = symbols('epsilon c E_f') - expr = eps * c * Ef**2 - f = lambda x: x[:,[0]]*x[:,[1]]*x[:,[2]]**2 - ranges = [[0,1],[0,1],[-1,1]] - - if name == 'II.27.18' or name == 74: - symbol = eps, Ef = symbols('epsilon E_f') - expr = eps * Ef**2 - f = lambda x: x[:,[0]]*x[:,[1]]**2 - ranges = [[0,1],[-1,1]] - - if name == 'II.34.2a' or name == 75: - symbol = q, v, r = symbols('q v r') - expr = q*v/(2*pi*r) - f = lambda x: x[:,[0]]*x[:,[1]]/(2*tpi*x[:,[2]]) - ranges = [[0,1],[0,1],[0.5,2]] - - if name == 'II.34.2' or name == 76: - symbol = q, v, r = symbols('q v r') - expr = q*v*r/2 - f = lambda x: x[:,[0]]*x[:,[1]]*x[:,[2]]/2 - ranges = [[0,1],[0,1],[0,1]] - - if name == 'II.34.11' or name == 77: - symbol = g, q, B, m = symbols('g q B m') - expr = g*q*B/(2*m) - f = lambda x: x[:,[0]]*x[:,[1]]*x[:,[2]]/(2*x[:,[3]]) - ranges = [[0,1],[0,1],[0,1],[0.5,2]] - - if name == 'II.34.29a' or name == 78: - symbol = q, h, m = symbols('q h m') - expr = q*h/(4*pi*m) - f = lambda x: x[:,[0]]*x[:,[1]]/(4*tpi*x[:,[2]]) - ranges = [[0,1],[0,1],[0.5,2]] - - if name == 'II.34.29b' or name == 79: - symbol = g, mu, B, J, hbar = symbols('g mu B J hbar') - expr = g*mu*B*J/hbar - f = lambda x: x[:,[0]]*x[:,[1]]*x[:,[2]]*x[:,[3]]/x[:,[4]] - ranges = [[0,1],[0,1],[0,1],[0,1],[0.5,2]] - - if name == 'II.35.18' or name == 80: - symbol = n0, mu, B, kb, T = symbols('n0 mu B k_b T') - expr = n0/(exp(mu*B/(kb*T))+exp(-mu*B/(kb*T))) - f = lambda x: x[:,[0]]/(torch.exp(x[:,[1]]*x[:,[2]]/(x[:,[3]]*x[:,[4]]))+torch.exp(-x[:,[1]]*x[:,[2]]/(x[:,[3]]*x[:,[4]]))) - ranges = [[0,1],[0,1],[0,1],[0.5,2],[0.5,2]] - - if name == 'II.35.21' or name == 81: - symbol = n, mu, B, kb, T = symbols('n mu B k_b T') - expr = n*mu*tanh(mu*B/(kb*T)) - f = lambda x: x[:,[0]]*x[:,[1]]*torch.tanh(x[:,[1]]*x[:,[2]]/(x[:,[3]]*x[:,[4]])) - ranges = [[0,1],[0,1],[0,1],[0.5,2],[0.5,2]] - - if name == 'II.36.38' or name == 82: - symbol = mu, B, kb, T, alpha, M, eps, c = symbols('mu B k_b T alpha M epsilon c') - expr = mu*B/(kb*T) + mu*alpha*M/(eps*c**2*kb*T) - f = lambda x: x[:,[0]]*x[:,[1]]/(x[:,[2]]*x[:,[3]]) + x[:,[0]]*x[:,[4]]*x[:,[5]]/(x[:,[6]]*x[:,[7]]**2*x[:,[2]]*x[:,[3]]) - ranges = [[0,1],[0,1],[0.5,2],[0.5,2],[0,1],[0,1],[0.5,2],[0.5,2]] - - if name == 'II.37.1' or name == 83: - symbol = mu, chi, B = symbols('mu chi B') - expr = mu*(1+chi)*B - f = lambda x: x[:,[0]]*(1+x[:,[1]])*x[:,[2]] - ranges = [[0,1],[0,1],[0,1]] - - if name == 'II.38.3' or name == 84: - symbol = Y, A, x, d = symbols('Y A x d') - expr = Y*A*x/d - f = lambda x: x[:,[0]]*x[:,[1]]*x[:,[2]]/x[:,[3]] - ranges = [[0,1],[0,1],[0,1],[0.5,2]] - - if name == 'II.38.14' 
or name == 85: - symbol = Y, sigma = symbols('Y sigma') - expr = Y/(2*(1+sigma)) - f = lambda x: x[:,[0]]/(2*(1+x[:,[1]])) - ranges = [[0,1],[0,1]] - - if name == 'III.4.32' or name == 86: - symbol = hbar, omega, kb, T = symbols('hbar omega k_b T') - expr = 1/(exp(hbar*omega/(kb*T))-1) - f = lambda x: 1/(torch.exp(x[:,[0]]*x[:,[1]]/(x[:,[2]]*x[:,[3]]))-1) - ranges = [[0.5,1],[0.5,1],[0.5,2],[0.5,2]] - - if name == 'III.4.33' or name == 87: - symbol = hbar, omega, kb, T = symbols('hbar omega k_b T') - expr = hbar*omega/(exp(hbar*omega/(kb*T))-1) - f = lambda x: x[:,[0]]*x[:,[1]]/(torch.exp(x[:,[0]]*x[:,[1]]/(x[:,[2]]*x[:,[3]]))-1) - ranges = [[0,1],[0,1],[0.5,2],[0.5,2]] - - if name == 'III.7.38' or name == 88: - symbol = mu, B, hbar = symbols('mu B hbar') - expr = 2*mu*B/hbar - f = lambda x: 2*x[:,[0]]*x[:,[1]]/x[:,[2]] - ranges = [[0,1],[0,1],[0.5,2]] - - if name == 'III.8.54' or name == 89: - symbol = E, t, hbar = symbols('E t hbar') - expr = sin(E*t/hbar)**2 - f = lambda x: torch.sin(x[:,[0]]*x[:,[1]]/x[:,[2]])**2 - ranges = [[0,2*tpi],[0,1],[0.5,2]] - - if name == 'III.9.52' or name == 90: - symbol = pd, Ef, t, hbar, omega, omega0 = symbols('p_d E_f t hbar omega omega_0') - expr = pd*Ef*t/hbar*sin((omega-omega0)*t/2)**2/((omega-omega0)*t/2)**2 - f = lambda x: x[:,[0]]*x[:,[1]]*x[:,[2]]/x[:,[3]]*torch.sin((x[:,[4]]-x[:,[5]])*x[:,[2]]/2)**2/((x[:,[4]]-x[:,[5]])*x[:,[2]]/2)**2 - ranges = [[0,1],[0,1],[0,1],[0.5,2],[0,tpi],[0,tpi]] - - if name == 'III.10.19' or name == 91: - symbol = mu, Bx, By, Bz = symbols('mu B_x B_y B_z') - expr = mu*sqrt(Bx**2+By**2+Bz**2) - f = lambda x: x[:,[0]]*torch.sqrt(x[:,[1]]**2+x[:,[2]]**2+x[:,[3]]**2) - ranges = [[0,1],[0,1],[0,1],[0,1]] - - if name == 'III.12.43' or name == 92: - symbol = n, hbar = symbols('n hbar') - expr = n * hbar - f = lambda x: x[:,[0]]*x[:,[1]] - ranges = [[0,1],[0,1]] - - if name == 'III.13.18' or name == 93: - symbol = E, d, k, hbar = symbols('E d k hbar') - expr = 2*E*d**2*k/hbar - f = lambda x: 2*x[:,[0]]*x[:,[1]]**2*x[:,[2]]/x[:,[3]] - ranges = [[0,1],[0,1],[0,1],[0.5,2]] - - if name == 'III.14.14' or name == 94: - symbol = I0, q, Ve, kb, T = symbols('I_0 q V_e k_b T') - expr = I0 * (exp(q*Ve/(kb*T))-1) - f = lambda x: x[:,[0]]*(torch.exp(x[:,[1]]*x[:,[2]]/(x[:,[3]]*x[:,[4]]))-1) - ranges = [[0,1],[0,1],[0,1],[0.5,2],[0.5,2]] - - if name == 'III.15.12' or name == 95: - symbol = U, k, d = symbols('U k d') - expr = 2*U*(1-cos(k*d)) - f = lambda x: 2*x[:,[0]]*(1-torch.cos(x[:,[1]]*x[:,[2]])) - ranges = [[0,1],[0,2*tpi],[0,1]] - - if name == 'III.15.14' or name == 96: - symbol = hbar, E, d = symbols('hbar E d') - expr = hbar**2/(2*E*d**2) - f = lambda x: x[:,[0]]**2/(2*x[:,[1]]*x[:,[2]]**2) - ranges = [[0,1],[0.5,2],[0.5,2]] - - if name == 'III.15.27' or name == 97: - symbol = alpha, n, d = symbols('alpha n d') - expr = 2*pi*alpha/(n*d) - f = lambda x: 2*tpi*x[:,[0]]/(x[:,[1]]*x[:,[2]]) - ranges = [[0,1],[0.5,2],[0.5,2]] - - if name == 'III.17.37' or name == 98: - symbol = beta, alpha, theta = symbols('beta alpha theta') - expr = beta * (1+alpha*cos(theta)) - f = lambda x: x[:,[0]]*(1+x[:,[1]]*torch.cos(x[:,[2]])) - ranges = [[0,1],[0,1],[0,2*tpi]] - - if name == 'III.19.51' or name == 99: - symbol = m, q, eps, hbar, n = symbols('m q epsilon hbar n') - expr = - m * q**4/(2*(4*pi*eps)**2*hbar**2)*1/n**2 - f = lambda x: - x[:,[0]]*x[:,[1]]**4/(2*(4*tpi*x[:,[2]])**2*x[:,[3]]**2)*1/x[:,[4]]**2 - ranges = [[0,1],[0,1],[0.5,2],[0.5,2],[0.5,2]] - - if name == 'III.21.20' or name == 100: - symbol = rho, q, A, m = symbols('rho q A m') - expr 
= - rho*q*A/m - f = lambda x: - x[:,[0]]*x[:,[1]]*x[:,[2]]/x[:,[3]] - ranges = [[0,1],[0,1],[0,1],[0.5,2]] - - if name == 'Rutherforld scattering' or name == 101: - symbol = Z1, Z2, alpha, hbar, c, E, theta = symbols('Z_1 Z_2 alpha hbar c E theta') - expr = (Z1*Z2*alpha*hbar*c/(4*E*sin(theta/2)**2))**2 - f = lambda x: (x[:,[0]]*x[:,[1]]*x[:,[2]]*x[:,[3]]*x[:,[4]]/(4*x[:,[5]]*torch.sin(x[:,[6]]/2)**2))**2 - ranges = [[0,1],[0,1],[0,1],[0,1],[0,1],[0.5,2],[0.1*tpi,0.9*tpi]] - - if name == 'Friedman equation' or name == 102: - symbol = G, rho, kf, c, af = symbols('G rho k_f c a_f') - expr = sqrt(8*pi*G/3*rho-kf*c**2/af**2) - f = lambda x: torch.sqrt(8*tpi*x[:,[0]]/3*x[:,[1]] - x[:,[2]]*x[:,[3]]**2/x[:,[4]]**2) - ranges = [[1,2],[1,2],[0,1],[0,1],[1,2]] - - if name == 'Compton scattering' or name == 103: - symbol = E, m, c, theta = symbols('E m c theta') - expr = E/(1+E/(m*c**2)*(1-cos(theta))) - f = lambda x: x[:,[0]]/(1+x[:,[0]]/(x[:,[1]]*x[:,[2]]**2)*(1-torch.cos(x[:,[3]]))) - ranges = [[0,1],[0.5,2],[0.5,2],[0,2*tpi]] - - if name == 'Radiated gravitational wave power' or name == 104: - symbol = G, c, m1, m2, r = symbols('G c m_1 m_2 r') - expr = -32/5*G**4/c**5*(m1*m2)**2*(m1+m2)/r**5 - f = lambda x: -32/5*x[:,[0]]**4/x[:,[1]]**5*(x[:,[2]]*x[:,[3]])**2*(x[:,[2]]+x[:,[3]])/x[:,[4]]**5 - ranges = [[0,1],[0.5,2],[0,1],[0,1],[0.5,2]] - - if name == 'Relativistic aberration' or name == 105: - symbol = theta2, v, c = symbols('theta_2 v c') - expr = acos((cos(theta2)-v/c)/(1-v/c*cos(theta2))) - f = lambda x: torch.arccos((torch.cos(x[:,[0]])-x[:,[1]]/x[:,[2]])/(1-x[:,[1]]/x[:,[2]]*torch.cos(x[:,[0]]))) - ranges = [[0,tpi],[0,1],[1,2]] - - if name == 'N-slit diffraction' or name == 106: - symbol = I0, alpha, delta, N = symbols('I_0 alpha delta N') - expr = I0 * (sin(alpha/2)/(alpha/2)*sin(N*delta/2)/sin(delta/2))**2 - f = lambda x: x[:,[0]] * (torch.sin(x[:,[1]]/2)/(x[:,[1]]/2)*torch.sin(x[:,[3]]*x[:,[2]]/2)/torch.sin(x[:,[2]]/2))**2 - ranges = [[0,1],[0.1*tpi,0.9*tpi],[0.1*tpi,0.9*tpi],[0.5,1]] - - if name == 'Goldstein 3.16' or name == 107: - symbol = m, E, U, L, r = symbols('m E U L r') - expr = sqrt(2/m*(E-U-L**2/(2*m*r**2))) - f = lambda x: torch.sqrt(2/x[:,[0]]*(x[:,[1]]-x[:,[2]]-x[:,[3]]**2/(2*x[:,[0]]*x[:,[4]]**2))) - ranges = [[1,2],[2,3],[0,1],[0,1],[1,2]] - - if name == 'Goldstein 3.55' or name == 108: - symbol = m, kG, L, E, theta1, theta2 = symbols('m k_G L E theta_1 theta_2') - expr = m*kG/L**2*(1+sqrt(1+2*E*L**2/(m*kG**2))*cos(theta1-theta2)) - f = lambda x: x[:,[0]]*x[:,[1]]/x[:,[2]]**2*(1+torch.sqrt(1+2*x[:,[3]]*x[:,[2]]**2/(x[:,[0]]*x[:,[1]]**2))*torch.cos(x[:,[4]]-x[:,[5]])) - ranges = [[0.5,2],[0.5,2],[0.5,2],[0,1],[0,2*tpi],[0,2*tpi]] - - if name == 'Goldstein 3.64 (ellipse)' or name == 109: - symbol = d, alpha, theta1, theta2 = symbols('d alpha theta_1 theta_2') - expr = d*(1-alpha**2)/(1+alpha*cos(theta2-theta1)) - f = lambda x: x[:,[0]]*(1-x[:,[1]]**2)/(1+x[:,[1]]*torch.cos(x[:,[2]]-x[:,[3]])) - ranges = [[0,1],[0,0.9],[0,2*tpi],[0,2*tpi]] - - if name == 'Goldstein 3.74 (Kepler)' or name == 110: - symbol = d, G, m1, m2 = symbols('d G m_1 m_2') - expr = 2*pi*d**(3/2)/sqrt(G*(m1+m2)) - f = lambda x: 2*tpi*x[:,[0]]**(3/2)/torch.sqrt(x[:,[1]]*(x[:,[2]]+x[:,[3]])) - ranges = [[0,1],[0.5,2],[0.5,2],[0.5,2]] - - if name == 'Goldstein 3.99' or name == 111: - symbol = eps, E, L, m, Z1, Z2, q = symbols('epsilon E L m Z_1 Z_2 q') - expr = sqrt(1+2*eps**2*E*L**2/(m*(Z1*Z2*q**2)**2)) - f = lambda x: 
torch.sqrt(1+2*x[:,[0]]**2*x[:,[1]]*x[:,[2]]**2/(x[:,[3]]*(x[:,[4]]*x[:,[5]]*x[:,[6]]**2)**2)) - ranges = [[0,1],[0,1],[0,1],[0.5,2],[0.5,2],[0.5,2],[0.5,2]] - - if name == 'Goldstein 8.56' or name == 112: - symbol = p, q, A, c, m, Ve = symbols('p q A c m V_e') - expr = sqrt((p-q*A)**2*c**2+m**2*c**4) + q*Ve - f = lambda x: torch.sqrt((x[:,[0]]-x[:,[1]]*x[:,[2]])**2*x[:,[3]]**2+x[:,[4]]**2*x[:,[3]]**4) + x[:,[1]]*x[:,[5]] - ranges = [0,1] - - if name == 'Goldstein 12.80' or name == 113: - symbol = m, p, omega, x, alpha, y = symbols('m p omega x alpha y') - expr = 1/(2*m)*(p**2+m**2*omega**2*x**2*(1+alpha*y/x)) - f = lambda x: 1/(2*x[:,[0]]) * (x[:,[1]]**2+x[:,[0]]**2*x[:,[2]]**2*x[:,[3]]**2*(1+x[:,[4]]*x[:,[5]]/x[:,[3]])) - ranges = [[0.5,2],[0,1],[0,1],[0,1],[0,1],[0.5,2]] - - if name == 'Jackson 2.11' or name == 114: - symbol = q, eps, y, Ve, d = symbols('q epsilon y V_e d') - expr = q/(4*pi*eps*y**2)*(4*pi*eps*Ve*d-q*d*y**3/(y**2-d**2)**2) - f = lambda x: x[:,[0]]/(4*tpi*x[:,[1]]*x[:,[2]]**2)*(4*tpi*x[:,[1]]*x[:,[3]]*x[:,[4]]-x[:,[0]]*x[:,[4]]*x[:,[2]]**3/(x[:,[2]]**2-x[:,[4]]**2)**2) - ranges = [[0,1],[0.5,2],[1,2],[0,1],[0,1]] - - if name == 'Jackson 3.45' or name == 115: - symbol = q, r, d, alpha = symbols('q r d alpha') - expr = q/sqrt(r**2+d**2-2*d*r*cos(alpha)) - f = lambda x: x[:,[0]]/torch.sqrt(x[:,[1]]**2+x[:,[2]]**2-2*x[:,[1]]*x[:,[2]]*torch.cos(x[:,[3]])) - ranges = [[0,1],[0,1],[0,1],[0,2*tpi]] - - if name == 'Jackson 4.60' or name == 116: - symbol = Ef, theta, alpha, d, r = symbols('E_f theta alpha d r') - expr = Ef * cos(theta) * ((alpha-1)/(alpha+2) * d**3/r**2 - r) - f = lambda x: x[:,[0]] * torch.cos(x[:,[1]]) * ((x[:,[2]]-1)/(x[:,[2]]+2) * x[:,[3]]**3/x[:,[4]]**2 - x[:,[4]]) - ranges = [[0,1],[0,2*tpi],[0,2],[0,1],[0.5,2]] - - if name == 'Jackson 11.38 (Doppler)' or name == 117: - symbol = omega, v, c, theta = symbols('omega v c theta') - expr = sqrt(1-v**2/c**2)/(1+v/c*cos(theta))*omega - f = lambda x: torch.sqrt(1-x[:,[1]]**2/x[:,[2]]**2)/(1+x[:,[1]]/x[:,[2]]*torch.cos(x[:,[3]]))*x[:,[0]] - ranges = [[0,1],[0,1],[1,2],[0,2*tpi]] - - if name == 'Weinberg 15.2.1' or name == 118: - symbol = G, c, kf, af, H = symbols('G c k_f a_f H') - expr = 3/(8*pi*G)*(c**2*kf/af**2+H**2) - f = lambda x: 3/(8*tpi*x[:,[0]])*(x[:,[1]]**2*x[:,[2]]/x[:,[3]]**2+x[:,[4]]**2) - ranges = [[0.5,2],[0,1],[0,1],[0.5,2],[0,1]] - - if name == 'Weinberg 15.2.2' or name == 119: - symbol = G, c, kf, af, H, alpha = symbols('G c k_f a_f H alpha') - expr = -1/(8*pi*G)*(c**4*kf/af**2+c**2*H**2*(1-2*alpha)) - f = lambda x: -1/(8*tpi*x[:,[0]])*(x[:,[1]]**4*x[:,[2]]/x[:,[3]]**2 + x[:,[1]]**2*x[:,[4]]**2*(1-2*x[:,[5]])) - ranges = [[0.5,2],[0,1],[0,1],[0.5,2],[0,1],[0,1]] - - if name == 'Schwarz 13.132 (Klein-Nishina)' or name == 120: - symbol = alpha, hbar, m, c, omega0, omega, theta = symbols('alpha hbar m c omega_0 omega theta') - expr = pi*alpha**2*hbar**2/m**2/c**2*(omega0/omega)**2*(omega0/omega+omega/omega0-sin(theta)**2) - f = lambda x: tpi*x[:,[0]]**2*x[:,[1]]**2/x[:,[2]]**2/x[:,[3]]**2*(x[:,[4]]/x[:,[5]])**2*(x[:,[4]]/x[:,[5]]+x[:,[5]]/x[:,[4]]-torch.sin(x[:,[6]])**2) - ranges = [[0,1],[0,1],[0.5,2],[0.5,2],[0.5,2],[0.5,2],[0,2*tpi]] - - return symbol, expr, f, ranges \ No newline at end of file diff --git a/models/others/kolmogorov_arnold_networks/kan/pytorch/kan/.ipynb_checkpoints/hypothesis-checkpoint.py b/models/others/kolmogorov_arnold_networks/kan/pytorch/kan/.ipynb_checkpoints/hypothesis-checkpoint.py deleted file mode 100644 index 
4850f509849c9efe21437b30b9ca2f220bf38181..0000000000000000000000000000000000000000 --- a/models/others/kolmogorov_arnold_networks/kan/pytorch/kan/.ipynb_checkpoints/hypothesis-checkpoint.py +++ /dev/null @@ -1,695 +0,0 @@ -import numpy as np -import torch -from sklearn.linear_model import LinearRegression -from sympy.utilities.lambdify import lambdify -from sklearn.cluster import AgglomerativeClustering -from .utils import batch_jacobian, batch_hessian -from functools import reduce -from kan.utils import batch_jacobian, batch_hessian -import copy -import matplotlib.pyplot as plt -import sympy -from sympy.printing import latex - - -def detect_separability(model, x, mode='add', score_th=1e-2, res_th=1e-2, n_clusters=None, bias=0., verbose=False): - ''' - detect function separability - - Args: - ----- - model : MultKAN, MLP or python function - x : 2D torch.float - inputs - mode : str - mode = 'add' or mode = 'mul' - score_th : float - threshold of score - res_th : float - threshold of residue - n_clusters : None or int - the number of clusters - bias : float - bias (for multiplicative separability) - verbose : bool - - Returns: - -------- - results (dictionary) - - Example1 - -------- - >>> from kan.hypothesis import * - >>> model = lambda x: x[:,[0]] ** 2 + torch.exp(x[:,[1]]+x[:,[2]]) - >>> x = torch.normal(0,1,size=(100,3)) - >>> detect_separability(model, x, mode='add') - - Example2 - -------- - >>> from kan.hypothesis import * - >>> model = lambda x: x[:,[0]] ** 2 * (x[:,[1]]+x[:,[2]]) - >>> x = torch.normal(0,1,size=(100,3)) - >>> detect_separability(model, x, mode='mul') - ''' - results = {} - - if mode == 'add': - hessian = batch_hessian(model, x) - elif mode == 'mul': - compose = lambda *F: reduce(lambda f, g: lambda x: f(g(x)), F) - hessian = batch_hessian(compose(torch.log, torch.abs, lambda x: x+bias, model), x) - - std = torch.std(x, dim=0) - hessian_normalized = hessian * std[None,:] * std[:,None] - score_mat = torch.median(torch.abs(hessian_normalized), dim=0)[0] - results['hessian'] = score_mat - - dist_hard = (score_mat < score_th).float() - - if isinstance(n_clusters, int): - n_cluster_try = [n_clusters, n_clusters] - elif isinstance(n_clusters, list): - n_cluster_try = n_clusters - else: - n_cluster_try = [1,x.shape[1]] - - n_cluster_try = list(range(n_cluster_try[0], n_cluster_try[1]+1)) - - for n_cluster in n_cluster_try: - - clustering = AgglomerativeClustering( - metric='precomputed', - n_clusters=n_cluster, - linkage='complete', - ).fit(dist_hard) - - labels = clustering.labels_ - - groups = [list(np.where(labels == i)[0]) for i in range(n_cluster)] - blocks = [torch.sum(score_mat[groups[i]][:,groups[i]]) for i in range(n_cluster)] - block_sum = torch.sum(torch.stack(blocks)) - total_sum = torch.sum(score_mat) - residual_sum = total_sum - block_sum - residual_ratio = residual_sum / total_sum - - if verbose == True: - print(f'n_group={n_cluster}, residual_ratio={residual_ratio}') - - if residual_ratio < res_th: - results['n_groups'] = n_cluster - results['labels'] = list(labels) - results['groups'] = groups - - if results['n_groups'] > 1: - print(f'{mode} separability detected') - else: - print(f'{mode} separability not detected') - - return results - - -def batch_grad_normgrad(model, x, group, create_graph=False): - # x in shape (Batch, Length) - group_A = group - group_B = list(set(range(x.shape[1])) - set(group)) - - def jac(x): - input_grad = batch_jacobian(model, x, create_graph=True) - input_grad_A = input_grad[:,group_A] - norm = torch.norm(input_grad_A, dim=1, 
keepdim=True) + 1e-6 - input_grad_A_normalized = input_grad_A/norm - return input_grad_A_normalized - - def _jac_sum(x): - return jac(x).sum(dim=0) - - return torch.autograd.functional.jacobian(_jac_sum, x, create_graph=create_graph).permute(1,0,2)[:,:,group_B] - - -def get_dependence(model, x, group): - group_A = group - group_B = list(set(range(x.shape[1])) - set(group)) - grad_normgrad = batch_grad_normgrad(model, x, group=group) - std = torch.std(x, dim=0) - dependence = grad_normgrad * std[None,group_A,None] * std[None,None,group_B] - dependence = torch.median(torch.abs(dependence), dim=0)[0] - return dependence - -def test_symmetry(model, x, group, dependence_th=1e-3): - ''' - detect function separability - - Args: - ----- - model : MultKAN, MLP or python function - x : 2D torch.float - inputs - group : a list of indices - dependence_th : float - threshold of dependence - - Returns: - -------- - bool - - Example - ------- - >>> from kan.hypothesis import * - >>> model = lambda x: x[:,[0]] ** 2 * (x[:,[1]]+x[:,[2]]) - >>> x = torch.normal(0,1,size=(100,3)) - >>> print(test_symmetry(model, x, [1,2])) # True - >>> print(test_symmetry(model, x, [0,2])) # False - ''' - if len(group) == x.shape[1] or len(group) == 0: - return True - - dependence = get_dependence(model, x, group) - max_dependence = torch.max(dependence) - return max_dependence < dependence_th - - -def test_separability(model, x, groups, mode='add', threshold=1e-2, bias=0): - ''' - test function separability - - Args: - ----- - model : MultKAN, MLP or python function - x : 2D torch.float - inputs - mode : str - mode = 'add' or mode = 'mul' - score_th : float - threshold of score - res_th : float - threshold of residue - bias : float - bias (for multiplicative separability) - verbose : bool - - Returns: - -------- - bool - - Example - ------- - >>> from kan.hypothesis import * - >>> model = lambda x: x[:,[0]] ** 2 * (x[:,[1]]+x[:,[2]]) - >>> x = torch.normal(0,1,size=(100,3)) - >>> print(test_separability(model, x, [[0],[1,2]], mode='mul')) # True - >>> print(test_separability(model, x, [[0],[1,2]], mode='add')) # False - ''' - if mode == 'add': - hessian = batch_hessian(model, x) - elif mode == 'mul': - compose = lambda *F: reduce(lambda f, g: lambda x: f(g(x)), F) - hessian = batch_hessian(compose(torch.log, torch.abs, lambda x: x+bias, model), x) - - std = torch.std(x, dim=0) - hessian_normalized = hessian * std[None,:] * std[:,None] - score_mat = torch.median(torch.abs(hessian_normalized), dim=0)[0] - - sep_bool = True - - # internal test - n_groups = len(groups) - for i in range(n_groups): - for j in range(i+1, n_groups): - sep_bool *= torch.max(score_mat[groups[i]][:,groups[j]]) < threshold - - # external test - group_id = [x for xs in groups for x in xs] - nongroup_id = list(set(range(x.shape[1])) - set(group_id)) - if len(nongroup_id) > 0 and len(group_id) > 0: - sep_bool *= torch.max(score_mat[group_id][:,nongroup_id]) < threshold - - return sep_bool - -def test_general_separability(model, x, groups, threshold=1e-2): - ''' - test function separability - - Args: - ----- - model : MultKAN, MLP or python function - x : 2D torch.float - inputs - mode : str - mode = 'add' or mode = 'mul' - score_th : float - threshold of score - res_th : float - threshold of residue - bias : float - bias (for multiplicative separability) - verbose : bool - - Returns: - -------- - bool - - Example - ------- - >>> from kan.hypothesis import * - >>> model = lambda x: x[:,[0]] ** 2 * (x[:,[1]]**2+x[:,[2]]**2)**2 - >>> x = 
torch.normal(0,1,size=(100,3)) - >>> print(test_general_separability(model, x, [[1],[0,2]])) # False - >>> print(test_general_separability(model, x, [[0],[1,2]])) # True - ''' - grad = batch_jacobian(model, x) - - gensep_bool = True - - n_groups = len(groups) - for i in range(n_groups): - for j in range(i+1,n_groups): - group_A = groups[i] - group_B = groups[j] - for member_A in group_A: - for member_B in group_B: - def func(x): - grad = batch_jacobian(model, x, create_graph=True) - return grad[:,[member_B]]/grad[:,[member_A]] - # test if func is multiplicative separable - gensep_bool *= test_separability(func, x, groups, mode='mul', threshold=threshold) - return gensep_bool - - -def get_molecule(model, x, sym_th=1e-3, verbose=True): - ''' - how variables are combined hierarchically - - Args: - ----- - model : MultKAN, MLP or python function - x : 2D torch.float - inputs - sym_th : float - threshold of symmetry - verbose : bool - - Returns: - -------- - list - - Example - ------- - >>> from kan.hypothesis import * - >>> model = lambda x: ((x[:,[0]] ** 2 + x[:,[1]] ** 2) ** 2 + (x[:,[2]] ** 2 + x[:,[3]] ** 2) ** 2) ** 2 + ((x[:,[4]] ** 2 + x[:,[5]] ** 2) ** 2 + (x[:,[6]] ** 2 + x[:,[7]] ** 2) ** 2) ** 2 - >>> x = torch.normal(0,1,size=(100,8)) - >>> get_molecule(model, x, verbose=False) - [[[0], [1], [2], [3], [4], [5], [6], [7]], - [[0, 1], [2, 3], [4, 5], [6, 7]], - [[0, 1, 2, 3], [4, 5, 6, 7]], - [[0, 1, 2, 3, 4, 5, 6, 7]]] - ''' - n = x.shape[1] - atoms = [[i] for i in range(n)] - molecules = [] - moleculess = [copy.deepcopy(atoms)] - already_full = False - n_layer = 0 - last_n_molecule = n - - while True: - - - pointer = 0 - current_molecule = [] - remove_atoms = [] - n_atom = 0 - - while len(atoms) > 0: - - # assemble molecule - atom = atoms[pointer] - if verbose: - print(current_molecule) - print(atom) - - if len(current_molecule) == 0: - full = False - current_molecule += atom - remove_atoms.append(atom) - n_atom += 1 - else: - # try assemble the atom to the molecule - if len(current_molecule+atom) == x.shape[1] and already_full == False and n_atom > 1 and n_layer > 0: - full = True - already_full = True - else: - full = False - if test_symmetry(model, x, current_molecule+atom, dependence_th=sym_th): - current_molecule += atom - remove_atoms.append(atom) - n_atom += 1 - - pointer += 1 - - if pointer == len(atoms) or full: - molecules.append(current_molecule) - if full: - molecules.append(atom) - remove_atoms.append(atom) - # remove molecules from atoms - for atom in remove_atoms: - atoms.remove(atom) - current_molecule = [] - remove_atoms = [] - pointer = 0 - - # if not making progress, terminate - if len(molecules) == last_n_molecule: - def flatten(xss): - return [x for xs in xss for x in xs] - moleculess.append([flatten(molecules)]) - break - else: - moleculess.append(copy.deepcopy(molecules)) - - last_n_molecule = len(molecules) - - if len(molecules) == 1: - break - - atoms = molecules - molecules = [] - - n_layer += 1 - - #print(n_layer, atoms) - - - # sort - depth = len(moleculess) - 1 - - for l in list(range(depth,0,-1)): - - molecules_sorted = [] - molecules_l = moleculess[l] - molecules_lm1 = moleculess[l-1] - - - for molecule_l in molecules_l: - start = 0 - for i in range(1,len(molecule_l)+1): - if molecule_l[start:i] in molecules_lm1: - - molecules_sorted.append(molecule_l[start:i]) - start = i - - moleculess[l-1] = molecules_sorted - - return moleculess - - -def get_tree_node(model, x, moleculess, sep_th=1e-2, skip_test=True): - ''' - get tree nodes - - Args: - ----- - 
model : MultKAN, MLP or python function - x : 2D torch.float - inputs - sep_th : float - threshold of separability - skip_test : bool - if True, don't test the property of each module (to save time) - - Returns: - -------- - arities : list of numbers - properties : list of strings - - Example - ------- - >>> from kan.hypothesis import * - >>> model = lambda x: ((x[:,[0]] ** 2 + x[:,[1]] ** 2) ** 2 + (x[:,[2]] ** 2 + x[:,[3]] ** 2) ** 2) ** 2 + ((x[:,[4]] ** 2 + x[:,[5]] ** 2) ** 2 + (x[:,[6]] ** 2 + x[:,[7]] ** 2) ** 2) ** 2 - >>> x = torch.normal(0,1,size=(100,8)) - >>> moleculess = get_molecule(model, x, verbose=False) - >>> get_tree_node(model, x, moleculess, skip_test=False) - ''' - arities = [] - properties = [] - - depth = len(moleculess) - 1 - - for l in range(depth): - molecules_l = copy.deepcopy(moleculess[l]) - molecules_lp1 = copy.deepcopy(moleculess[l+1]) - arity_l = [] - property_l = [] - - for molecule in molecules_lp1: - start = 0 - arity = 0 - groups = [] - for i in range(1,len(molecule)+1): - if molecule[start:i] in molecules_l: - groups.append(molecule[start:i]) - start = i - arity += 1 - arity_l.append(arity) - - if arity == 1: - property = 'Id' - else: - property = '' - # test property - if skip_test: - gensep_bool = False - else: - gensep_bool = test_general_separability(model, x, groups, threshold=sep_th) - - if gensep_bool: - property = 'GS' - if l == depth - 1: - if skip_test: - add_bool = False - mul_bool = False - else: - add_bool = test_separability(model, x, groups, mode='add', threshold=sep_th) - mul_bool = test_separability(model, x, groups, mode='mul', threshold=sep_th) - if add_bool: - property = 'Add' - if mul_bool: - property = 'Mul' - - - property_l.append(property) - - - arities.append(arity_l) - properties.append(property_l) - - return arities, properties - - -def plot_tree(model, x, in_var=None, style='tree', sym_th=1e-3, sep_th=1e-1, skip_sep_test=False, verbose=False): - ''' - get tree graph - - Args: - ----- - model : MultKAN, MLP or python function - x : 2D torch.float - inputs - in_var : list of symbols - input variables - style : str - 'tree' or 'box' - sym_th : float - threshold of symmetry - sep_th : float - threshold of separability - skip_sep_test : bool - if True, don't test the property of each module (to save time) - verbose : bool - - Returns: - -------- - a tree graph - - Example - ------- - >>> from kan.hypothesis import * - >>> model = lambda x: ((x[:,[0]] ** 2 + x[:,[1]] ** 2) ** 2 + (x[:,[2]] ** 2 + x[:,[3]] ** 2) ** 2) ** 2 + ((x[:,[4]] ** 2 + x[:,[5]] ** 2) ** 2 + (x[:,[6]] ** 2 + x[:,[7]] ** 2) ** 2) ** 2 - >>> x = torch.normal(0,1,size=(100,8)) - >>> plot_tree(model, x) - ''' - moleculess = get_molecule(model, x, sym_th=sym_th, verbose=verbose) - arities, properties = get_tree_node(model, x, moleculess, sep_th=sep_th, skip_test=skip_sep_test) - - n = x.shape[1] - var = None - - in_vars = [] - - if in_var == None: - for ii in range(1, n + 1): - exec(f"x{ii} = sympy.Symbol('x_{ii}')") - exec(f"in_vars.append(x{ii})") - elif type(var[0]) == Symbol: - in_vars = var - else: - in_vars = [sympy.symbols(var_) for var_ in var] - - - def flatten(xss): - return [x for xs in xss for x in xs] - - def myrectangle(center_x, center_y, width_x, width_y): - plt.plot([center_x - width_x/2, center_x + width_x/2], [center_y + width_y/2, center_y + width_y/2], color='k') # up - plt.plot([center_x - width_x/2, center_x + width_x/2], [center_y - width_y/2, center_y - width_y/2], color='k') # down - plt.plot([center_x - width_x/2, center_x - width_x/2], 
[center_y - width_y/2, center_y + width_y/2], color='k') # left - plt.plot([center_x + width_x/2, center_x + width_x/2], [center_y - width_y/2, center_y + width_y/2], color='k') # left - - depth = len(moleculess) - - delta = 1/n - a = 0.3 - b = 0.15 - y0 = 0.5 - - - # draw rectangles - for l in range(depth-1): - molecules = moleculess[l+1] - n_molecule = len(molecules) - - centers = [] - - acc_arity = 0 - - for i in range(n_molecule): - start_id = len(flatten(molecules[:i])) - end_id = len(flatten(molecules[:i+1])) - - center_x = (start_id + (end_id - 1 - start_id)/2) * delta + delta/2 - center_y = (l+1/2)*y0 - width_x = (end_id - start_id - 1 + 2*a)*delta - width_y = 2*b - - # add text (numbers) on rectangles - if style == 'box': - myrectangle(center_x, center_y, width_x, width_y) - plt.text(center_x, center_y, properties[l][i], fontsize=15, horizontalalignment='center', - verticalalignment='center') - elif style == 'tree': - # if 'GS', no rectangle, n=arity tilted lines - # if 'Id', no rectangle, n=arity vertical lines - # if 'Add' or 'Mul'. rectangle, "+" or "x" - # if '', rectangle - property = properties[l][i] - if property == 'GS' or property == 'Add' or property == 'Mul': - color = 'blue' - arity = arities[l][i] - for j in range(arity): - - if l == 0: - # x = (start_id + j) * delta + delta/2, center_x - # y = center_y - b, center_y + b - plt.plot([(start_id + j) * delta + delta/2, center_x], [center_y - b, center_y + b], color=color) - else: - # x = last_centers[acc_arity:acc_arity+arity], center_x - # y = center_y - b, center_y + b - plt.plot([last_centers[acc_arity+j], center_x], [center_y - b, center_y + b], color=color) - - acc_arity += arity - - if property == 'Add' or property == 'Mul': - if property == 'Add': - symbol = '+' - else: - symbol = '*' - - plt.text(center_x, center_y + b, symbol, horizontalalignment='center', - verticalalignment='center', color='red', fontsize=40) - if property == 'Id': - plt.plot([center_x, center_x], [center_y-width_y/2, center_y+width_y/2], color='black') - - if property == '': - myrectangle(center_x, center_y, width_x, width_y) - - - - # connections to the next layer - plt.plot([center_x, center_x], [center_y+width_y/2, center_y+y0-width_y/2], color='k') - centers.append(center_x) - last_centers = copy.deepcopy(centers) - - # connections from input variables to the first layer - for i in range(n): - x_ = (i + 1/2) * delta - # connections to the next layer - plt.plot([x_, x_], [0, y0/2-width_y/2], color='k') - plt.text(x_, -0.05*(depth-1), f'${latex(in_vars[moleculess[0][i][0]])}$', fontsize=20, horizontalalignment='center') - plt.xlim(0,1) - #plt.ylim(0,1); - plt.axis('off'); - plt.show() - - -def test_symmetry_var(model, x, input_vars, symmetry_var): - ''' - test symmetry - - Args: - ----- - model : MultKAN, MLP or python function - x : 2D torch.float - inputs - input_vars : list of sympy symbols - symmetry_var : sympy expression - - Returns: - -------- - cosine similarity - - Example - ------- - >>> from kan.hypothesis import * - >>> from sympy import * - >>> model = lambda x: x[:,[0]] * (x[:,[1]] + x[:,[2]]) - >>> x = torch.normal(0,1,size=(100,8)) - >>> input_vars = a, b, c = symbols('a b c') - >>> symmetry_var = b + c - >>> test_symmetry_var(model, x, input_vars, symmetry_var); - >>> symmetry_var = b * c - >>> test_symmetry_var(model, x, input_vars, symmetry_var); - ''' - orig_vars = input_vars - sym_var = symmetry_var - - # gradients wrt to input (model) - input_grad = batch_jacobian(model, x) - - # gradients wrt to input (symmetry var) - 
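# lambdify compiles the sympy expression into a plain numpy function;
# func2 below adapts it to the (batch, n_vars) tensor layout expected by
# batch_jacobian, so the symbolic gradient can be compared column-wise
# with the model's gradient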
func = lambdify(orig_vars, sym_var,'numpy') # returns a numpy-ready function - - func2 = lambda x: func(*[x[:,[i]] for i in range(len(orig_vars))]) - sym_grad = batch_jacobian(func2, x) - - # get id - idx = [] - sym_symbols = list(sym_var.free_symbols) - for sym_symbol in sym_symbols: - for j in range(len(orig_vars)): - if sym_symbol == orig_vars[j]: - idx.append(j) - - input_grad_part = input_grad[:,idx] - sym_grad_part = sym_grad[:,idx] - - cossim = torch.abs(torch.sum(input_grad_part * sym_grad_part, dim=1)/(torch.norm(input_grad_part, dim=1)*torch.norm(sym_grad_part, dim=1))) - - ratio = torch.sum(cossim > 0.9)/len(cossim) - - print(f'{100*ratio}% data have more than 0.9 cosine similarity') - if ratio > 0.9: - print('suggesting symmetry') - else: - print('not suggesting symmetry') - - return cossim \ No newline at end of file diff --git a/models/others/kolmogorov_arnold_networks/kan/pytorch/kan/.ipynb_checkpoints/spline-checkpoint.py b/models/others/kolmogorov_arnold_networks/kan/pytorch/kan/.ipynb_checkpoints/spline-checkpoint.py deleted file mode 100644 index 6953bf08193073301dff651a9d78ad9b40c5fd53..0000000000000000000000000000000000000000 --- a/models/others/kolmogorov_arnold_networks/kan/pytorch/kan/.ipynb_checkpoints/spline-checkpoint.py +++ /dev/null @@ -1,144 +0,0 @@ -import torch - - -def B_batch(x, grid, k=0, extend=True, device='cpu'): - ''' - evaludate x on B-spline bases - - Args: - ----- - x : 2D torch.tensor - inputs, shape (number of splines, number of samples) - grid : 2D torch.tensor - grids, shape (number of splines, number of grid points) - k : int - the piecewise polynomial order of splines. - extend : bool - If True, k points are extended on both ends. If False, no extension (zero boundary condition). Default: True - device : str - devicde - - Returns: - -------- - spline values : 3D torch.tensor - shape (batch, in_dim, G+k). G: the number of grid intervals, k: spline order. - - Example - ------- - >>> from kan.spline import B_batch - >>> x = torch.rand(100,2) - >>> grid = torch.linspace(-1,1,steps=11)[None, :].expand(2, 11) - >>> B_batch(x, grid, k=3).shape - ''' - - x = x.unsqueeze(dim=2) - grid = grid.unsqueeze(dim=0) - - if k == 0: - value = (x >= grid[:, :, :-1]) * (x < grid[:, :, 1:]) - else: - B_km1 = B_batch(x[:,:,0], grid=grid[0], k=k - 1) - - value = (x - grid[:, :, :-(k + 1)]) / (grid[:, :, k:-1] - grid[:, :, :-(k + 1)]) * B_km1[:, :, :-1] + ( - grid[:, :, k + 1:] - x) / (grid[:, :, k + 1:] - grid[:, :, 1:(-k)]) * B_km1[:, :, 1:] - - # in case grid is degenerate - value = torch.nan_to_num(value) - return value - - - -def coef2curve(x_eval, grid, coef, k, device="cpu"): - ''' - converting B-spline coefficients to B-spline curves. Evaluate x on B-spline curves (summing up B_batch results over B-spline basis). - - Args: - ----- - x_eval : 2D torch.tensor - shape (batch, in_dim) - grid : 2D torch.tensor - shape (in_dim, G+2k). G: the number of grid intervals; k: spline order. - coef : 3D torch.tensor - shape (in_dim, out_dim, G+k) - k : int - the piecewise polynomial order of splines. - device : str - devicde - - Returns: - -------- - y_eval : 3D torch.tensor - shape (batch, in_dim, out_dim) - - ''' - - b_splines = B_batch(x_eval, grid, k=k) - y_eval = torch.einsum('ijk,jlk->ijl', b_splines, coef.to(b_splines.device)) - - return y_eval - - -def curve2coef(x_eval, y_eval, grid, k): - ''' - converting B-spline curves to B-spline coefficients using least squares. 
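For each (in_dim, out_dim) pair this solves the linear least-squares problem min_c ||B c - y||^2, where B is the (batch, n_coef) matrix of B-spline bases evaluated at x_eval.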
- - Args: - ----- - x_eval : 2D torch.tensor - shape (batch, in_dim) - y_eval : 3D torch.tensor - shape (batch, in_dim, out_dim) - grid : 2D torch.tensor - shape (in_dim, grid+2*k) - k : int - spline order - lamb : float - regularized least square lambda - - Returns: - -------- - coef : 3D torch.tensor - shape (in_dim, out_dim, G+k) - ''' - #print('haha', x_eval.shape, y_eval.shape, grid.shape) - batch = x_eval.shape[0] - in_dim = x_eval.shape[1] - out_dim = y_eval.shape[2] - n_coef = grid.shape[1] - k - 1 - mat = B_batch(x_eval, grid, k) - mat = mat.permute(1,0,2)[:,None,:,:].expand(in_dim, out_dim, batch, n_coef) - #print('mat', mat.shape) - y_eval = y_eval.permute(1,2,0).unsqueeze(dim=3) - #print('y_eval', y_eval.shape) - device = mat.device - - #coef = torch.linalg.lstsq(mat, y_eval, driver='gelsy' if device == 'cpu' else 'gels').solution[:,:,:,0] - try: - coef = torch.linalg.lstsq(mat, y_eval).solution[:,:,:,0] - except: - print('lstsq failed') - - # manual psuedo-inverse - '''lamb=1e-8 - XtX = torch.einsum('ijmn,ijnp->ijmp', mat.permute(0,1,3,2), mat) - Xty = torch.einsum('ijmn,ijnp->ijmp', mat.permute(0,1,3,2), y_eval) - n1, n2, n = XtX.shape[0], XtX.shape[1], XtX.shape[2] - identity = torch.eye(n,n)[None, None, :, :].expand(n1, n2, n, n).to(device) - A = XtX + lamb * identity - B = Xty - coef = (A.pinverse() @ B)[:,:,:,0]''' - - return coef - - -def extend_grid(grid, k_extend=0): - ''' - extend grid - ''' - h = (grid[:, [-1]] - grid[:, [0]]) / (grid.shape[1] - 1) - - for i in range(k_extend): - grid = torch.cat([grid[:, [0]] - h, grid], dim=1) - grid = torch.cat([grid, grid[:, [-1]] + h], dim=1) - - return grid \ No newline at end of file diff --git a/models/others/kolmogorov_arnold_networks/kan/pytorch/kan/.ipynb_checkpoints/utils-checkpoint.py b/models/others/kolmogorov_arnold_networks/kan/pytorch/kan/.ipynb_checkpoints/utils-checkpoint.py deleted file mode 100644 index abb4d558ba0b8bfd92356f4d41fd1cfcb5e7bf55..0000000000000000000000000000000000000000 --- a/models/others/kolmogorov_arnold_networks/kan/pytorch/kan/.ipynb_checkpoints/utils-checkpoint.py +++ /dev/null @@ -1,594 +0,0 @@ -import numpy as np -import torch -from sklearn.linear_model import LinearRegression -import sympy -import yaml -from sympy.utilities.lambdify import lambdify -import re - -# sigmoid = sympy.Function('sigmoid') -# name: (torch implementation, sympy implementation) - -# singularity protection functions -f_inv = lambda x, y_th: ((x_th := 1/y_th), y_th/x_th*x * (torch.abs(x) < x_th) + torch.nan_to_num(1/x) * (torch.abs(x) >= x_th)) -f_inv2 = lambda x, y_th: ((x_th := 1/y_th**(1/2)), y_th * (torch.abs(x) < x_th) + torch.nan_to_num(1/x**2) * (torch.abs(x) >= x_th)) -f_inv3 = lambda x, y_th: ((x_th := 1/y_th**(1/3)), y_th/x_th*x * (torch.abs(x) < x_th) + torch.nan_to_num(1/x**3) * (torch.abs(x) >= x_th)) -f_inv4 = lambda x, y_th: ((x_th := 1/y_th**(1/4)), y_th * (torch.abs(x) < x_th) + torch.nan_to_num(1/x**4) * (torch.abs(x) >= x_th)) -f_inv5 = lambda x, y_th: ((x_th := 1/y_th**(1/5)), y_th/x_th*x * (torch.abs(x) < x_th) + torch.nan_to_num(1/x**5) * (torch.abs(x) >= x_th)) -f_sqrt = lambda x, y_th: ((x_th := 1/y_th**2), x_th/y_th*x * (torch.abs(x) < x_th) + torch.nan_to_num(torch.sqrt(torch.abs(x))*torch.sign(x)) * (torch.abs(x) >= x_th)) -f_power1d5 = lambda x, y_th: torch.abs(x)**1.5 -f_invsqrt = lambda x, y_th: ((x_th := 1/y_th**2), y_th * (torch.abs(x) < x_th) + torch.nan_to_num(1/torch.sqrt(torch.abs(x))) * (torch.abs(x) >= x_th)) -f_log = lambda x, y_th: ((x_th := torch.e**(-y_th)), - y_th * 
(torch.abs(x) < x_th) + torch.nan_to_num(torch.log(torch.abs(x))) * (torch.abs(x) >= x_th)) -f_tan = lambda x, y_th: ((clip := x % torch.pi), (delta := torch.pi/2-torch.arctan(y_th)), - y_th/delta * (clip - torch.pi/2) * (torch.abs(clip - torch.pi/2) < delta) + torch.nan_to_num(torch.tan(clip)) * (torch.abs(clip - torch.pi/2) >= delta)) -f_arctanh = lambda x, y_th: ((delta := 1-torch.tanh(y_th) + 1e-4), y_th * torch.sign(x) * (torch.abs(x) > 1 - delta) + torch.nan_to_num(torch.arctanh(x)) * (torch.abs(x) <= 1 - delta)) -f_arcsin = lambda x, y_th: ((), torch.pi/2 * torch.sign(x) * (torch.abs(x) > 1) + torch.nan_to_num(torch.arcsin(x)) * (torch.abs(x) <= 1)) -f_arccos = lambda x, y_th: ((), torch.pi/2 * (1-torch.sign(x)) * (torch.abs(x) > 1) + torch.nan_to_num(torch.arccos(x)) * (torch.abs(x) <= 1)) -f_exp = lambda x, y_th: ((x_th := torch.log(y_th)), y_th * (x > x_th) + torch.exp(x) * (x <= x_th)) - -SYMBOLIC_LIB = {'x': (lambda x: x, lambda x: x, 1, lambda x, y_th: ((), x)), - 'x^2': (lambda x: x**2, lambda x: x**2, 2, lambda x, y_th: ((), x**2)), - 'x^3': (lambda x: x**3, lambda x: x**3, 3, lambda x, y_th: ((), x**3)), - 'x^4': (lambda x: x**4, lambda x: x**4, 3, lambda x, y_th: ((), x**4)), - 'x^5': (lambda x: x**5, lambda x: x**5, 3, lambda x, y_th: ((), x**5)), - '1/x': (lambda x: 1/x, lambda x: 1/x, 2, f_inv), - '1/x^2': (lambda x: 1/x**2, lambda x: 1/x**2, 2, f_inv2), - '1/x^3': (lambda x: 1/x**3, lambda x: 1/x**3, 3, f_inv3), - '1/x^4': (lambda x: 1/x**4, lambda x: 1/x**4, 4, f_inv4), - '1/x^5': (lambda x: 1/x**5, lambda x: 1/x**5, 5, f_inv5), - 'sqrt': (lambda x: torch.sqrt(x), lambda x: sympy.sqrt(x), 2, f_sqrt), - 'x^0.5': (lambda x: torch.sqrt(x), lambda x: sympy.sqrt(x), 2, f_sqrt), - 'x^1.5': (lambda x: torch.sqrt(x)**3, lambda x: sympy.sqrt(x)**3, 4, f_power1d5), - '1/sqrt(x)': (lambda x: 1/torch.sqrt(x), lambda x: 1/sympy.sqrt(x), 2, f_invsqrt), - '1/x^0.5': (lambda x: 1/torch.sqrt(x), lambda x: 1/sympy.sqrt(x), 2, f_invsqrt), - 'exp': (lambda x: torch.exp(x), lambda x: sympy.exp(x), 2, f_exp), - 'log': (lambda x: torch.log(x), lambda x: sympy.log(x), 2, f_log), - 'abs': (lambda x: torch.abs(x), lambda x: sympy.Abs(x), 3, lambda x, y_th: ((), torch.abs(x))), - 'sin': (lambda x: torch.sin(x), lambda x: sympy.sin(x), 2, lambda x, y_th: ((), torch.sin(x))), - 'cos': (lambda x: torch.cos(x), lambda x: sympy.cos(x), 2, lambda x, y_th: ((), torch.cos(x))), - 'tan': (lambda x: torch.tan(x), lambda x: sympy.tan(x), 3, f_tan), - 'tanh': (lambda x: torch.tanh(x), lambda x: sympy.tanh(x), 3, lambda x, y_th: ((), torch.tanh(x))), - 'sgn': (lambda x: torch.sign(x), lambda x: sympy.sign(x), 3, lambda x, y_th: ((), torch.sign(x))), - 'arcsin': (lambda x: torch.arcsin(x), lambda x: sympy.asin(x), 4, f_arcsin), - 'arccos': (lambda x: torch.arccos(x), lambda x: sympy.acos(x), 4, f_arccos), - 'arctan': (lambda x: torch.arctan(x), lambda x: sympy.atan(x), 4, lambda x, y_th: ((), torch.arctan(x))), - 'arctanh': (lambda x: torch.arctanh(x), lambda x: sympy.atanh(x), 4, f_arctanh), - '0': (lambda x: x*0, lambda x: x*0, 0, lambda x, y_th: ((), x*0)), - 'gaussian': (lambda x: torch.exp(-x**2), lambda x: sympy.exp(-x**2), 3, lambda x, y_th: ((), torch.exp(-x**2))), - #'cosh': (lambda x: torch.cosh(x), lambda x: sympy.cosh(x), 5), - #'sigmoid': (lambda x: torch.sigmoid(x), sympy.Function('sigmoid'), 4), - #'relu': (lambda x: torch.relu(x), relu), -} - -def create_dataset(f, - n_var=2, - f_mode = 'col', - ranges = [-1,1], - train_num=1000, - test_num=1000, - normalize_input=False, - 
normalize_label=False, - device='cpu', - seed=0): - ''' - create dataset - - Args: - ----- - f : function - the symbolic formula used to create the synthetic dataset - ranges : list or np.array; shape (2,) or (n_var, 2) - the range of input variables. Default: [-1,1]. - train_num : int - the number of training samples. Default: 1000. - test_num : int - the number of test samples. Default: 1000. - normalize_input : bool - If True, apply normalization to inputs. Default: False. - normalize_label : bool - If True, apply normalization to labels. Default: False. - device : str - device. Default: 'cpu'. - seed : int - random seed. Default: 0. - - Returns: - -------- - dataset : dic - Train/test inputs/labels are dataset['train_input'], dataset['train_label'], - dataset['test_input'], dataset['test_label'] - - Example - ------- - >>> f = lambda x: torch.exp(torch.sin(torch.pi*x[:,[0]]) + x[:,[1]]**2) - >>> dataset = create_dataset(f, n_var=2, train_num=100) - >>> dataset['train_input'].shape - torch.Size([100, 2]) - ''' - - np.random.seed(seed) - torch.manual_seed(seed) - - if len(np.array(ranges).shape) == 1: - ranges = np.array(ranges * n_var).reshape(n_var,2) - else: - ranges = np.array(ranges) - - - train_input = torch.zeros(train_num, n_var) - test_input = torch.zeros(test_num, n_var) - for i in range(n_var): - train_input[:,i] = torch.rand(train_num,)*(ranges[i,1]-ranges[i,0])+ranges[i,0] - test_input[:,i] = torch.rand(test_num,)*(ranges[i,1]-ranges[i,0])+ranges[i,0] - - if f_mode == 'col': - train_label = f(train_input) - test_label = f(test_input) - elif f_mode == 'row': - train_label = f(train_input.T) - test_label = f(test_input.T) - else: - print(f'f_mode {f_mode} not recognized') - - # if has only 1 dimension - if len(train_label.shape) == 1: - train_label = train_label.unsqueeze(dim=1) - test_label = test_label.unsqueeze(dim=1) - - def normalize(data, mean, std): - return (data-mean)/std - - if normalize_input == True: - mean_input = torch.mean(train_input, dim=0, keepdim=True) - std_input = torch.std(train_input, dim=0, keepdim=True) - train_input = normalize(train_input, mean_input, std_input) - test_input = normalize(test_input, mean_input, std_input) - - if normalize_label == True: - mean_label = torch.mean(train_label, dim=0, keepdim=True) - std_label = torch.std(train_label, dim=0, keepdim=True) - train_label = normalize(train_label, mean_label, std_label) - test_label = normalize(test_label, mean_label, std_label) - - dataset = {} - dataset['train_input'] = train_input.to(device) - dataset['test_input'] = test_input.to(device) - - dataset['train_label'] = train_label.to(device) - dataset['test_label'] = test_label.to(device) - - return dataset - - - -def fit_params(x, y, fun, a_range=(-10,10), b_range=(-10,10), grid_number=101, iteration=3, verbose=True, device='cpu'): - ''' - fit a, b, c, d such that - - .. math:: - |y-(cf(ax+b)+d)|^2 - - is minimized. Both x and y are 1D array. Sweep a and b, find the best fitted model. 
- - Args: - ----- - x : 1D array - x values - y : 1D array - y values - fun : function - symbolic function - a_range : tuple - sweeping range of a - b_range : tuple - sweeping range of b - grid_num : int - number of steps along a and b - iteration : int - number of zooming in - verbose : bool - print extra information if True - device : str - device - - Returns: - -------- - a_best : float - best fitted a - b_best : float - best fitted b - c_best : float - best fitted c - d_best : float - best fitted d - r2_best : float - best r2 (coefficient of determination) - - Example - ------- - >>> num = 100 - >>> x = torch.linspace(-1,1,steps=num) - >>> noises = torch.normal(0,1,(num,)) * 0.02 - >>> y = 5.0*torch.sin(3.0*x + 2.0) + 0.7 + noises - >>> fit_params(x, y, torch.sin) - r2 is 0.9999727010726929 - (tensor([2.9982, 1.9996, 5.0053, 0.7011]), tensor(1.0000)) - ''' - # fit a, b, c, d such that y=c*fun(a*x+b)+d; both x and y are 1D array. - # sweep a and b, choose the best fitted model - for _ in range(iteration): - a_ = torch.linspace(a_range[0], a_range[1], steps=grid_number, device=device) - b_ = torch.linspace(b_range[0], b_range[1], steps=grid_number, device=device) - a_grid, b_grid = torch.meshgrid(a_, b_, indexing='ij') - post_fun = fun(a_grid[None,:,:] * x[:,None,None] + b_grid[None,:,:]) - x_mean = torch.mean(post_fun, dim=[0], keepdim=True) - y_mean = torch.mean(y, dim=[0], keepdim=True) - numerator = torch.sum((post_fun - x_mean)*(y-y_mean)[:,None,None], dim=0)**2 - denominator = torch.sum((post_fun - x_mean)**2, dim=0)*torch.sum((y - y_mean)[:,None,None]**2, dim=0) - r2 = numerator/(denominator+1e-4) - r2 = torch.nan_to_num(r2) - - - best_id = torch.argmax(r2) - a_id, b_id = torch.div(best_id, grid_number, rounding_mode='floor'), best_id % grid_number - - - if a_id == 0 or a_id == grid_number - 1 or b_id == 0 or b_id == grid_number - 1: - if _ == 0 and verbose==True: - print('Best value at boundary.') - if a_id == 0: - a_range = [a_[0], a_[1]] - if a_id == grid_number - 1: - a_range = [a_[-2], a_[-1]] - if b_id == 0: - b_range = [b_[0], b_[1]] - if b_id == grid_number - 1: - b_range = [b_[-2], b_[-1]] - - else: - a_range = [a_[a_id-1], a_[a_id+1]] - b_range = [b_[b_id-1], b_[b_id+1]] - - a_best = a_[a_id] - b_best = b_[b_id] - post_fun = fun(a_best * x + b_best) - r2_best = r2[a_id, b_id] - - if verbose == True: - print(f"r2 is {r2_best}") - if r2_best < 0.9: - print(f'r2 is not very high, please double check if you are choosing the correct symbolic function.') - - post_fun = torch.nan_to_num(post_fun) - reg = LinearRegression().fit(post_fun[:,None].detach().cpu().numpy(), y.detach().cpu().numpy()) - c_best = torch.from_numpy(reg.coef_)[0].to(device) - d_best = torch.from_numpy(np.array(reg.intercept_)).to(device) - return torch.stack([a_best, b_best, c_best, d_best]), r2_best - - -def sparse_mask(in_dim, out_dim): - ''' - get sparse mask - ''' - in_coord = torch.arange(in_dim) * 1/in_dim + 1/(2*in_dim) - out_coord = torch.arange(out_dim) * 1/out_dim + 1/(2*out_dim) - - dist_mat = torch.abs(out_coord[:,None] - in_coord[None,:]) - in_nearest = torch.argmin(dist_mat, dim=0) - in_connection = torch.stack([torch.arange(in_dim), in_nearest]).permute(1,0) - out_nearest = torch.argmin(dist_mat, dim=1) - out_connection = torch.stack([out_nearest, torch.arange(out_dim)]).permute(1,0) - all_connection = torch.cat([in_connection, out_connection], dim=0) - mask = torch.zeros(in_dim, out_dim) - mask[all_connection[:,0], all_connection[:,1]] = 1. 
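# the resulting mask wires every input to its nearest output and every
# output to its nearest input, so each neuron keeps at least one
# active connection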
- - return mask - - -def add_symbolic(name, fun, c=1, fun_singularity=None): - ''' - add a symbolic function to library - - Args: - ----- - name : str - name of the function - fun : fun - torch function or lambda function - - Returns: - -------- - None - - Example - ------- - >>> print(SYMBOLIC_LIB['Bessel']) - KeyError: 'Bessel' - >>> add_symbolic('Bessel', torch.special.bessel_j0) - >>> print(SYMBOLIC_LIB['Bessel']) - (, Bessel) - ''' - exec(f"globals()['{name}'] = sympy.Function('{name}')") - if fun_singularity==None: - fun_singularity = fun - SYMBOLIC_LIB[name] = (fun, globals()[name], c, fun_singularity) - - -def ex_round(ex1, n_digit): - ''' - rounding the numbers in an expression to certain floating points - - Args: - ----- - ex1 : sympy expression - n_digit : int - - Returns: - -------- - ex2 : sympy expression - - Example - ------- - >>> from kan.utils import * - >>> from sympy import * - >>> input_vars = a, b = symbols('a b') - >>> expression = 3.14534242 * exp(sin(pi*a) + b**2) - 2.32345402 - >>> ex_round(expression, 2) - ''' - ex2 = ex1 - for a in sympy.preorder_traversal(ex1): - if isinstance(a, sympy.Float): - ex2 = ex2.subs(a, round(a, n_digit)) - return ex2 - - -def augment_input(orig_vars, aux_vars, x): - ''' - augment inputs - - Args: - ----- - orig_vars : list of sympy symbols - aux_vars : list of auxiliary symbols - x : inputs - - Returns: - -------- - augmented inputs - - Example - ------- - >>> from kan.utils import * - >>> from sympy import * - >>> orig_vars = a, b = symbols('a b') - >>> aux_vars = [a + b, a * b] - >>> x = torch.rand(100, 2) - >>> augment_input(orig_vars, aux_vars, x).shape - ''' - # if x is a tensor - if isinstance(x, torch.Tensor): - - aux_values = torch.tensor([]).to(x.device) - - for aux_var in aux_vars: - func = lambdify(orig_vars, aux_var,'numpy') # returns a numpy-ready function - aux_value = torch.from_numpy(func(*[x[:,[i]].numpy() for i in range(len(orig_vars))])) - aux_values = torch.cat([aux_values, aux_value], dim=1) - - x = torch.cat([aux_values, x], dim=1) - - # if x is a dataset - elif isinstance(x, dict): - x['train_input'] = augment_input(orig_vars, aux_vars, x['train_input']) - x['test_input'] = augment_input(orig_vars, aux_vars, x['test_input']) - - return x - - -def batch_jacobian(func, x, create_graph=False, mode='scalar'): - ''' - jacobian - - Args: - ----- - func : function or model - x : inputs - create_graph : bool - - Returns: - -------- - jacobian - - Example - ------- - >>> from kan.utils import batch_jacobian - >>> x = torch.normal(0,1,size=(100,2)) - >>> model = lambda x: x[:,[0]] + x[:,[1]] - >>> batch_jacobian(model, x) - ''' - # x in shape (Batch, Length) - def _func_sum(x): - return func(x).sum(dim=0) - if mode == 'scalar': - return torch.autograd.functional.jacobian(_func_sum, x, create_graph=create_graph)[0] - elif mode == 'vector': - return torch.autograd.functional.jacobian(_func_sum, x, create_graph=create_graph).permute(1,0,2) - -def batch_hessian(model, x, create_graph=False): - ''' - hessian - - Args: - ----- - func : function or model - x : inputs - create_graph : bool - - Returns: - -------- - jacobian - - Example - ------- - >>> from kan.utils import batch_hessian - >>> x = torch.normal(0,1,size=(100,2)) - >>> model = lambda x: x[:,[0]]**2 + x[:,[1]]**2 - >>> batch_hessian(model, x) - ''' - # x in shape (Batch, Length) - jac = lambda x: batch_jacobian(model, x, create_graph=True) - def _jac_sum(x): - return jac(x).sum(dim=0) - return torch.autograd.functional.jacobian(_jac_sum, x, 
create_graph=create_graph).permute(1,0,2) - - -def create_dataset_from_data(inputs, labels, train_ratio=0.8, device='cpu'): - ''' - create dataset from data - - Args: - ----- - inputs : 2D torch.float - labels : 2D torch.float - train_ratio : float - the ratio of training fraction - device : str - - Returns: - -------- - dataset (dictionary) - - Example - ------- - >>> from kan.utils import create_dataset_from_data - >>> x = torch.normal(0,1,size=(100,2)) - >>> y = torch.normal(0,1,size=(100,1)) - >>> dataset = create_dataset_from_data(x, y) - >>> dataset['train_input'].shape - ''' - num = inputs.shape[0] - train_id = np.random.choice(num, int(num*train_ratio), replace=False) - test_id = list(set(np.arange(num)) - set(train_id)) - dataset = {} - dataset['train_input'] = inputs[train_id].detach().to(device) - dataset['test_input'] = inputs[test_id].detach().to(device) - dataset['train_label'] = labels[train_id].detach().to(device) - dataset['test_label'] = labels[test_id].detach().to(device) - - return dataset - - -def get_derivative(model, inputs, labels, derivative='hessian', loss_mode='pred', reg_metric='w', lamb=0., lamb_l1=1., lamb_entropy=0.): - ''' - compute the jacobian/hessian of loss wrt to model parameters - - Args: - ----- - inputs : 2D torch.float - labels : 2D torch.float - derivative : str - 'jacobian' or 'hessian' - device : str - - Returns: - -------- - jacobian or hessian - ''' - def get_mapping(model): - - mapping = {} - name = 'model1' - - keys = list(model.state_dict().keys()) - for key in keys: - - y = re.findall(".[0-9]+", key) - if len(y) > 0: - y = y[0][1:] - x = re.split(".[0-9]+", key) - mapping[key] = name + '.' + x[0] + '[' + y + ']' + x[1] - - - y = re.findall("_[0-9]+", key) - if len(y) > 0: - y = y[0][1:] - x = re.split(".[0-9]+", key) - mapping[key] = name + '.' 
+ x[0] + '[' + y + ']' - - return mapping - - - #model1 = copy.deepcopy(model) - model1 = model.copy() - mapping = get_mapping(model) - - # collect keys and shapes - keys = list(model.state_dict().keys()) - shapes = [] - - for params in model.parameters(): - shapes.append(params.shape) - - - # turn a flattened vector to model params - def param2statedict(p, keys, shapes): - - new_state_dict = {} - - start = 0 - n_group = len(keys) - for i in range(n_group): - shape = shapes[i] - n_params = torch.prod(torch.tensor(shape)) - new_state_dict[keys[i]] = p[start:start+n_params].reshape(shape) - start += n_params - - return new_state_dict - - def differentiable_load_state_dict(mapping, state_dict, model1): - - for key in keys: - if mapping[key][-1] != ']': - exec(f"del {mapping[key]}") - exec(f"{mapping[key]} = state_dict[key]") - - - # input: p, output: output - def get_param2loss_fun(inputs, labels): - - def param2loss_fun(p): - - p = p[0] - state_dict = param2statedict(p, keys, shapes) - # this step is non-differentiable - #model.load_state_dict(state_dict) - differentiable_load_state_dict(mapping, state_dict, model1) - if loss_mode == 'pred': - pred_loss = torch.mean((model1(inputs) - labels)**2, dim=(0,1), keepdim=True) - loss = pred_loss - elif loss_mode == 'reg': - reg_loss = model1.get_reg(reg_metric=reg_metric, lamb_l1=lamb_l1, lamb_entropy=lamb_entropy) * torch.ones(1,1) - loss = reg_loss - elif loss_mode == 'all': - pred_loss = torch.mean((model1(inputs) - labels)**2, dim=(0,1), keepdim=True) - reg_loss = model1.get_reg(reg_metric=reg_metric, lamb_l1=lamb_l1, lamb_entropy=lamb_entropy) * torch.ones(1,1) - loss = pred_loss + lamb * reg_loss - return loss - - return param2loss_fun - - fun = get_param2loss_fun(inputs, labels) - p = model2param(model)[None,:] - if derivative == 'hessian': - result = batch_hessian(fun, p) - elif derivative == 'jacobian': - result = batch_jacobian(fun, p) - return result - -def model2param(model): - ''' - turn model parameters into a flattened vector - ''' - p = torch.tensor([]).to(model.device) - for params in model.parameters(): - p = torch.cat([p, params.reshape(-1,)], dim=0) - return p diff --git a/models/others/kolmogorov_arnold_networks/kan/pytorch/kan/KANLayer.py b/models/others/kolmogorov_arnold_networks/kan/pytorch/kan/KANLayer.py deleted file mode 100644 index b880bfe8b58b69edd1b7d5dfe7234ec27d6af8ec..0000000000000000000000000000000000000000 --- a/models/others/kolmogorov_arnold_networks/kan/pytorch/kan/KANLayer.py +++ /dev/null @@ -1,364 +0,0 @@ -import torch -import torch.nn as nn -import numpy as np -from .spline import * -from .utils import sparse_mask - - -class KANLayer(nn.Module): - """ - KANLayer class - - - Attributes: - ----------- - in_dim: int - input dimension - out_dim: int - output dimension - num: int - the number of grid intervals - k: int - the piecewise polynomial order of splines - noise_scale: float - spline scale at initialization - coef: 2D torch.tensor - coefficients of B-spline bases - scale_base_mu: float - magnitude of the residual function b(x) is drawn from N(mu, sigma^2), mu = sigma_base_mu - scale_base_sigma: float - magnitude of the residual function b(x) is drawn from N(mu, sigma^2), mu = sigma_base_sigma - scale_sp: float - mangitude of the spline function spline(x) - base_fun: fun - residual function b(x) - mask: 1D torch.float - mask of spline functions. setting some element of the mask to zero means setting the corresponding activation to zero function. 
- grid_eps: float in [0,1] - a hyperparameter used in update_grid_from_samples. When grid_eps = 1, the grid is uniform; when grid_eps = 0, the grid is partitioned using percentiles of samples. 0 < grid_eps < 1 interpolates between the two extremes. - the id of activation functions that are locked - device: str - device - """ - - def __init__(self, in_dim=3, out_dim=2, num=5, k=3, noise_scale=0.5, scale_base_mu=0.0, scale_base_sigma=1.0, scale_sp=1.0, base_fun=torch.nn.SiLU(), grid_eps=0.02, grid_range=[-1, 1], sp_trainable=True, sb_trainable=True, save_plot_data = True, device='cpu', sparse_init=False): - '''' - initialize a KANLayer - - Args: - ----- - in_dim : int - input dimension. Default: 2. - out_dim : int - output dimension. Default: 3. - num : int - the number of grid intervals = G. Default: 5. - k : int - the order of piecewise polynomial. Default: 3. - noise_scale : float - the scale of noise injected at initialization. Default: 0.1. - scale_base_mu : float - the scale of the residual function b(x) is intialized to be N(scale_base_mu, scale_base_sigma^2). - scale_base_sigma : float - the scale of the residual function b(x) is intialized to be N(scale_base_mu, scale_base_sigma^2). - scale_sp : float - the scale of the base function spline(x). - base_fun : function - residual function b(x). Default: torch.nn.SiLU() - grid_eps : float - When grid_eps = 1, the grid is uniform; when grid_eps = 0, the grid is partitioned using percentiles of samples. 0 < grid_eps < 1 interpolates between the two extremes. - grid_range : list/np.array of shape (2,) - setting the range of grids. Default: [-1,1]. - sp_trainable : bool - If true, scale_sp is trainable - sb_trainable : bool - If true, scale_base is trainable - device : str - device - sparse_init : bool - if sparse_init = True, sparse initialization is applied. 
- - Returns: - -------- - self - - Example - ------- - >>> from kan.KANLayer import * - >>> model = KANLayer(in_dim=3, out_dim=5) - >>> (model.in_dim, model.out_dim) - ''' - super(KANLayer, self).__init__() - # size - self.out_dim = out_dim - self.in_dim = in_dim - self.num = num - self.k = k - - grid = torch.linspace(grid_range[0], grid_range[1], steps=num + 1)[None,:].expand(self.in_dim, num+1) - grid = extend_grid(grid, k_extend=k) - self.grid = torch.nn.Parameter(grid).requires_grad_(False) - noises = (torch.rand(self.num+1, self.in_dim, self.out_dim) - 1/2) * noise_scale / num - - self.coef = torch.nn.Parameter(curve2coef(self.grid[:,k:-k].permute(1,0), noises, self.grid, k)) - - if sparse_init: - self.mask = torch.nn.Parameter(sparse_mask(in_dim, out_dim)).requires_grad_(False) - else: - self.mask = torch.nn.Parameter(torch.ones(in_dim, out_dim)).requires_grad_(False) - - self.scale_base = torch.nn.Parameter(scale_base_mu * 1 / np.sqrt(in_dim) + \ - scale_base_sigma * (torch.rand(in_dim, out_dim)*2-1) * 1/np.sqrt(in_dim)).requires_grad_(sb_trainable) - self.scale_sp = torch.nn.Parameter(torch.ones(in_dim, out_dim) * scale_sp * 1 / np.sqrt(in_dim) * self.mask).requires_grad_(sp_trainable) # make scale trainable - self.base_fun = base_fun - - - self.grid_eps = grid_eps - - self.to(device) - - def to(self, device): - super(KANLayer, self).to(device) - self.device = device - return self - - def forward(self, x): - ''' - KANLayer forward given input x - - Args: - ----- - x : 2D torch.float - inputs, shape (number of samples, input dimension) - - Returns: - -------- - y : 2D torch.float - outputs, shape (number of samples, output dimension) - preacts : 3D torch.float - fan out x into activations, shape (number of sampels, output dimension, input dimension) - postacts : 3D torch.float - the outputs of activation functions with preacts as inputs - postspline : 3D torch.float - the outputs of spline functions with preacts as inputs - - Example - ------- - >>> from kan.KANLayer import * - >>> model = KANLayer(in_dim=3, out_dim=5) - >>> x = torch.normal(0,1,size=(100,3)) - >>> y, preacts, postacts, postspline = model(x) - >>> y.shape, preacts.shape, postacts.shape, postspline.shape - ''' - batch = x.shape[0] - preacts = x[:,None,:].clone().expand(batch, self.out_dim, self.in_dim) - - base = self.base_fun(x) # (batch, in_dim) - y = coef2curve(x_eval=x, grid=self.grid, coef=self.coef, k=self.k) - - postspline = y.clone().permute(0,2,1) - - y = self.scale_base[None,:,:] * base[:,:,None] + self.scale_sp[None,:,:] * y - y = self.mask[None,:,:] * y - - postacts = y.clone().permute(0,2,1) - - y = torch.sum(y, dim=1) - return y, preacts, postacts, postspline - - def update_grid_from_samples(self, x, mode='sample'): - ''' - update grid from samples - - Args: - ----- - x : 2D torch.float - inputs, shape (number of samples, input dimension) - - Returns: - -------- - None - - Example - ------- - >>> model = KANLayer(in_dim=1, out_dim=1, num=5, k=3) - >>> print(model.grid.data) - >>> x = torch.linspace(-3,3,steps=100)[:,None] - >>> model.update_grid_from_samples(x) - >>> print(model.grid.data) - ''' - - batch = x.shape[0] - #x = torch.einsum('ij,k->ikj', x, torch.ones(self.out_dim, ).to(self.device)).reshape(batch, self.size).permute(1, 0) - x_pos = torch.sort(x, dim=0)[0] - y_eval = coef2curve(x_pos, self.grid, self.coef, self.k) - num_interval = self.grid.shape[1] - 1 - 2*self.k - - def get_grid(num_interval): - ids = [int(batch / num_interval * i) for i in range(num_interval)] + [-1] - grid_adaptive = 
x_pos[ids, :].permute(1,0) - margin = 0.00 - h = (grid_adaptive[:,[-1]] - grid_adaptive[:,[0]] + 2 * margin)/num_interval - grid_uniform = grid_adaptive[:,[0]] - margin + h * torch.arange(num_interval+1,)[None, :].to(x.device) - grid = self.grid_eps * grid_uniform + (1 - self.grid_eps) * grid_adaptive - return grid - - - grid = get_grid(num_interval) - - if mode == 'grid': - sample_grid = get_grid(2*num_interval) - x_pos = sample_grid.permute(1,0) - y_eval = coef2curve(x_pos, self.grid, self.coef, self.k) - - self.grid.data = extend_grid(grid, k_extend=self.k) - #print('x_pos 2', x_pos.shape) - #print('y_eval 2', y_eval.shape) - self.coef.data = curve2coef(x_pos, y_eval, self.grid, self.k) - - def initialize_grid_from_parent(self, parent, x, mode='sample'): - ''' - update grid from a parent KANLayer & samples - - Args: - ----- - parent : KANLayer - a parent KANLayer (whose grid is usually coarser than the current model) - x : 2D torch.float - inputs, shape (number of samples, input dimension) - - Returns: - -------- - None - - Example - ------- - >>> batch = 100 - >>> parent_model = KANLayer(in_dim=1, out_dim=1, num=5, k=3) - >>> print(parent_model.grid.data) - >>> model = KANLayer(in_dim=1, out_dim=1, num=10, k=3) - >>> x = torch.normal(0,1,size=(batch, 1)) - >>> model.initialize_grid_from_parent(parent_model, x) - >>> print(model.grid.data) - ''' - - batch = x.shape[0] - - # shrink grid - x_pos = torch.sort(x, dim=0)[0] - y_eval = coef2curve(x_pos, parent.grid, parent.coef, parent.k) - num_interval = self.grid.shape[1] - 1 - 2*self.k - - - ''' - # based on samples - def get_grid(num_interval): - ids = [int(batch / num_interval * i) for i in range(num_interval)] + [-1] - grid_adaptive = x_pos[ids, :].permute(1,0) - h = (grid_adaptive[:,[-1]] - grid_adaptive[:,[0]])/num_interval - grid_uniform = grid_adaptive[:,[0]] + h * torch.arange(num_interval+1,)[None, :].to(x.device) - grid = self.grid_eps * grid_uniform + (1 - self.grid_eps) * grid_adaptive - return grid''' - - #print('p', parent.grid) - # based on interpolating parent grid - def get_grid(num_interval): - x_pos = parent.grid[:,parent.k:-parent.k] - #print('x_pos', x_pos) - sp2 = KANLayer(in_dim=1, out_dim=self.in_dim,k=1,num=x_pos.shape[1]-1,scale_base_mu=0.0, scale_base_sigma=0.0).to(x.device) - - #print('sp2_grid', sp2.grid[:,sp2.k:-sp2.k].permute(1,0).expand(-1,self.in_dim)) - #print('sp2_coef_shape', sp2.coef.shape) - sp2_coef = curve2coef(sp2.grid[:,sp2.k:-sp2.k].permute(1,0).expand(-1,self.in_dim), x_pos.permute(1,0).unsqueeze(dim=2), sp2.grid[:,:], k=1).permute(1,0,2) - shp = sp2_coef.shape - #sp2_coef = torch.cat([torch.zeros(shp[0], shp[1], 1), sp2_coef, torch.zeros(shp[0], shp[1], 1)], dim=2) - #print('sp2_coef',sp2_coef) - #print(sp2.coef.shape) - sp2.coef.data = sp2_coef - percentile = torch.linspace(-1,1,self.num+1).to(self.device) - grid = sp2(percentile.unsqueeze(dim=1))[0].permute(1,0) - #print('c', grid) - return grid - - grid = get_grid(num_interval) - - if mode == 'grid': - sample_grid = get_grid(2*num_interval) - x_pos = sample_grid.permute(1,0) - y_eval = coef2curve(x_pos, parent.grid, parent.coef, parent.k) - - grid = extend_grid(grid, k_extend=self.k) - self.grid.data = grid - self.coef.data = curve2coef(x_pos, y_eval, self.grid, self.k) - - def get_subset(self, in_id, out_id): - ''' - get a smaller KANLayer from a larger KANLayer (used for pruning) - - Args: - ----- - in_id : list - id of selected input neurons - out_id : list - id of selected output neurons - - Returns: - -------- - spb : KANLayer - - 
Example - ------- - >>> kanlayer_large = KANLayer(in_dim=10, out_dim=10, num=5, k=3) - >>> kanlayer_small = kanlayer_large.get_subset([0,9],[1,2,3]) - >>> kanlayer_small.in_dim, kanlayer_small.out_dim - (2, 3) - ''' - spb = KANLayer(len(in_id), len(out_id), self.num, self.k, base_fun=self.base_fun) - spb.grid.data = self.grid[in_id] - spb.coef.data = self.coef[in_id][:,out_id] - spb.scale_base.data = self.scale_base[in_id][:,out_id] - spb.scale_sp.data = self.scale_sp[in_id][:,out_id] - spb.mask.data = self.mask[in_id][:,out_id] - - spb.in_dim = len(in_id) - spb.out_dim = len(out_id) - return spb - - - def swap(self, i1, i2, mode='in'): - ''' - swap the i1 neuron with the i2 neuron in input (if mode == 'in') or output (if mode == 'out') - - Args: - ----- - i1 : int - i2 : int - mode : str - mode = 'in' or 'out' - - Returns: - -------- - None - - Example - ------- - >>> from kan.KANLayer import * - >>> model = KANLayer(in_dim=2, out_dim=2, num=5, k=3) - >>> print(model.coef) - >>> model.swap(0,1,mode='in') - >>> print(model.coef) - ''' - with torch.no_grad(): - def swap_(data, i1, i2, mode='in'): - if mode == 'in': - data[i1], data[i2] = data[i2].clone(), data[i1].clone() - elif mode == 'out': - data[:,i1], data[:,i2] = data[:,i2].clone(), data[:,i1].clone() - - if mode == 'in': - swap_(self.grid.data, i1, i2, mode='in') - swap_(self.coef.data, i1, i2, mode=mode) - swap_(self.scale_base.data, i1, i2, mode=mode) - swap_(self.scale_sp.data, i1, i2, mode=mode) - swap_(self.mask.data, i1, i2, mode=mode) - diff --git a/models/others/kolmogorov_arnold_networks/kan/pytorch/kan/LBFGS.py b/models/others/kolmogorov_arnold_networks/kan/pytorch/kan/LBFGS.py deleted file mode 100644 index 212477f23ec325ad80e4f2e849db8c895b380045..0000000000000000000000000000000000000000 --- a/models/others/kolmogorov_arnold_networks/kan/pytorch/kan/LBFGS.py +++ /dev/null @@ -1,493 +0,0 @@ -import torch -from functools import reduce -from torch.optim import Optimizer - -__all__ = ['LBFGS'] - -def _cubic_interpolate(x1, f1, g1, x2, f2, g2, bounds=None): - # ported from https://github.com/torch/optim/blob/master/polyinterp.lua - # Compute bounds of interpolation area - if bounds is not None: - xmin_bound, xmax_bound = bounds - else: - xmin_bound, xmax_bound = (x1, x2) if x1 <= x2 else (x2, x1) - - # Code for most common case: cubic interpolation of 2 points - # w/ function and derivative values for both - # Solution in this case (where x2 is the farthest point): - # d1 = g1 + g2 - 3*(f1-f2)/(x1-x2); - # d2 = sqrt(d1^2 - g1*g2); - # min_pos = x2 - (x2 - x1)*((g2 + d2 - d1)/(g2 - g1 + 2*d2)); - # t_new = min(max(min_pos,xmin_bound),xmax_bound); - d1 = g1 + g2 - 3 * (f1 - f2) / (x1 - x2) - d2_square = d1**2 - g1 * g2 - if d2_square >= 0: - d2 = d2_square.sqrt() - if x1 <= x2: - min_pos = x2 - (x2 - x1) * ((g2 + d2 - d1) / (g2 - g1 + 2 * d2)) - else: - min_pos = x1 - (x1 - x2) * ((g1 + d2 - d1) / (g1 - g2 + 2 * d2)) - return min(max(min_pos, xmin_bound), xmax_bound) - else: - return (xmin_bound + xmax_bound) / 2. 
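A quick numeric check of the interpolation formula above (a standalone sketch, not part of the deleted file; the variables mirror `_cubic_interpolate`'s arguments): when the left end of the bracket is descending (g1 < 0) and the right end is ascending (g2 > 0), the cubic minimizer lands strictly inside the bracket.

```python
import torch

# bracket endpoints with function values and directional derivatives,
# as passed by _strong_wolfe: descending at x1, ascending at x2
x1, f1, g1 = 0.0, torch.tensor(1.0), torch.tensor(-2.0)
x2, f2, g2 = 1.0, torch.tensor(0.5), torch.tensor(1.0)

d1 = g1 + g2 - 3 * (f1 - f2) / (x1 - x2)
d2 = (d1**2 - g1 * g2).sqrt()  # discriminant >= 0 here, so the cubic has a minimum
t = x2 - (x2 - x1) * ((g2 + d2 - d1) / (g2 - g1 + 2 * d2))
print(float(t))  # ~0.667, a trial step strictly inside (0, 1)
```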
- - -def _strong_wolfe(obj_func, - x, - t, - d, - f, - g, - gtd, - c1=1e-4, - c2=0.9, - tolerance_change=1e-9, - max_ls=25): - # ported from https://github.com/torch/optim/blob/master/lswolfe.lua - d_norm = d.abs().max() - g = g.clone(memory_format=torch.contiguous_format) - # evaluate objective and gradient using initial step - f_new, g_new = obj_func(x, t, d) - ls_func_evals = 1 - gtd_new = g_new.dot(d) - - # bracket an interval containing a point satisfying the Wolfe criteria - t_prev, f_prev, g_prev, gtd_prev = 0, f, g, gtd - done = False - ls_iter = 0 - while ls_iter < max_ls: - # check conditions - #print(f_prev, f_new, g_new) - if f_new > (f + c1 * t * gtd) or (ls_iter > 1 and f_new >= f_prev): - bracket = [t_prev, t] - bracket_f = [f_prev, f_new] - bracket_g = [g_prev, g_new.clone(memory_format=torch.contiguous_format)] - bracket_gtd = [gtd_prev, gtd_new] - break - - if abs(gtd_new) <= -c2 * gtd: - bracket = [t] - bracket_f = [f_new] - bracket_g = [g_new] - done = True - break - - if gtd_new >= 0: - bracket = [t_prev, t] - bracket_f = [f_prev, f_new] - bracket_g = [g_prev, g_new.clone(memory_format=torch.contiguous_format)] - bracket_gtd = [gtd_prev, gtd_new] - break - - # interpolate - min_step = t + 0.01 * (t - t_prev) - max_step = t * 10 - tmp = t - t = _cubic_interpolate( - t_prev, - f_prev, - gtd_prev, - t, - f_new, - gtd_new, - bounds=(min_step, max_step)) - - # next step - t_prev = tmp - f_prev = f_new - g_prev = g_new.clone(memory_format=torch.contiguous_format) - gtd_prev = gtd_new - f_new, g_new = obj_func(x, t, d) - ls_func_evals += 1 - gtd_new = g_new.dot(d) - ls_iter += 1 - - - # reached max number of iterations? - if ls_iter == max_ls: - bracket = [0, t] - bracket_f = [f, f_new] - bracket_g = [g, g_new] - - # zoom phase: we now have a point satisfying the criteria, or - # a bracket around it. We refine the bracket until we find the - # exact point satisfying the criteria - insuf_progress = False - # find high and low points in bracket - low_pos, high_pos = (0, 1) if bracket_f[0] <= bracket_f[-1] else (1, 0) - while not done and ls_iter < max_ls: - # line-search bracket is so small - if abs(bracket[1] - bracket[0]) * d_norm < tolerance_change: - break - - # compute new trial value - t = _cubic_interpolate(bracket[0], bracket_f[0], bracket_gtd[0], - bracket[1], bracket_f[1], bracket_gtd[1]) - - # test that we are making sufficient progress: - # in case `t` is so close to boundary, we mark that we are making - # insufficient progress, and if - # + we have made insufficient progress in the last step, or - # + `t` is at one of the boundary, - # we will move `t` to a position which is `0.1 * len(bracket)` - # away from the nearest boundary point. 
- eps = 0.1 * (max(bracket) - min(bracket)) - if min(max(bracket) - t, t - min(bracket)) < eps: - # interpolation close to boundary - if insuf_progress or t >= max(bracket) or t <= min(bracket): - # evaluate at 0.1 away from boundary - if abs(t - max(bracket)) < abs(t - min(bracket)): - t = max(bracket) - eps - else: - t = min(bracket) + eps - insuf_progress = False - else: - insuf_progress = True - else: - insuf_progress = False - - # Evaluate new point - f_new, g_new = obj_func(x, t, d) - ls_func_evals += 1 - gtd_new = g_new.dot(d) - ls_iter += 1 - - if f_new > (f + c1 * t * gtd) or f_new >= bracket_f[low_pos]: - # Armijo condition not satisfied or not lower than lowest point - bracket[high_pos] = t - bracket_f[high_pos] = f_new - bracket_g[high_pos] = g_new.clone(memory_format=torch.contiguous_format) - bracket_gtd[high_pos] = gtd_new - low_pos, high_pos = (0, 1) if bracket_f[0] <= bracket_f[1] else (1, 0) - else: - if abs(gtd_new) <= -c2 * gtd: - # Wolfe conditions satisfied - done = True - elif gtd_new * (bracket[high_pos] - bracket[low_pos]) >= 0: - # old low becomes new high - bracket[high_pos] = bracket[low_pos] - bracket_f[high_pos] = bracket_f[low_pos] - bracket_g[high_pos] = bracket_g[low_pos] - bracket_gtd[high_pos] = bracket_gtd[low_pos] - - # new point becomes new low - bracket[low_pos] = t - bracket_f[low_pos] = f_new - bracket_g[low_pos] = g_new.clone(memory_format=torch.contiguous_format) - bracket_gtd[low_pos] = gtd_new - - #print(bracket) - if len(bracket) == 1: - t = bracket[0] - f_new = bracket_f[0] - g_new = bracket_g[0] - else: - t = bracket[low_pos] - f_new = bracket_f[low_pos] - g_new = bracket_g[low_pos] - return f_new, g_new, t, ls_func_evals - - - -class LBFGS(Optimizer): - """Implements L-BFGS algorithm. - - Heavily inspired by `minFunc - `_. - - .. warning:: - This optimizer doesn't support per-parameter options and parameter - groups (there can be only one). - - .. warning:: - Right now all parameters have to be on a single device. This will be - improved in the future. - - .. note:: - This is a very memory intensive optimizer (it requires additional - ``param_bytes * (history_size + 1)`` bytes). If it doesn't fit in memory - try reducing the history size, or use a different algorithm. - - Args: - lr (float): learning rate (default: 1) - max_iter (int): maximal number of iterations per optimization step - (default: 20) - max_eval (int): maximal number of function evaluations per optimization - step (default: max_iter * 1.25). - tolerance_grad (float): termination tolerance on first order optimality - (default: 1e-7). - tolerance_change (float): termination tolerance on function - value/parameter changes (default: 1e-9). - history_size (int): update history size (default: 100). - line_search_fn (str): either 'strong_wolfe' or None (default: None). 
- """ - - def __init__(self, - params, - lr=1, - max_iter=20, - max_eval=None, - tolerance_grad=1e-7, - tolerance_change=1e-9, - tolerance_ys=1e-32, - history_size=100, - line_search_fn=None): - if max_eval is None: - max_eval = max_iter * 5 // 4 - defaults = dict( - lr=lr, - max_iter=max_iter, - max_eval=max_eval, - tolerance_grad=tolerance_grad, - tolerance_change=tolerance_change, - tolerance_ys=tolerance_ys, - history_size=history_size, - line_search_fn=line_search_fn) - super().__init__(params, defaults) - - if len(self.param_groups) != 1: - raise ValueError("LBFGS doesn't support per-parameter options " - "(parameter groups)") - - self._params = self.param_groups[0]['params'] - self._numel_cache = None - - def _numel(self): - if self._numel_cache is None: - self._numel_cache = reduce(lambda total, p: total + p.numel(), self._params, 0) - return self._numel_cache - - def _gather_flat_grad(self): - views = [] - for p in self._params: - if p.grad is None: - view = p.new(p.numel()).zero_() - elif p.grad.is_sparse: - view = p.grad.to_dense().view(-1) - else: - view = p.grad.view(-1) - views.append(view) - device = views[0].device - return torch.cat(views, dim=0) - - def _add_grad(self, step_size, update): - offset = 0 - for p in self._params: - numel = p.numel() - # view as to avoid deprecated pointwise semantics - p.add_(update[offset:offset + numel].view_as(p), alpha=step_size) - offset += numel - assert offset == self._numel() - - def _clone_param(self): - return [p.clone(memory_format=torch.contiguous_format) for p in self._params] - - def _set_param(self, params_data): - for p, pdata in zip(self._params, params_data): - p.copy_(pdata) - - def _directional_evaluate(self, closure, x, t, d): - self._add_grad(t, d) - loss = float(closure()) - flat_grad = self._gather_flat_grad() - self._set_param(x) - return loss, flat_grad - - - @torch.no_grad() - def step(self, closure): - """Perform a single optimization step. - - Args: - closure (Callable): A closure that reevaluates the model - and returns the loss. 
- """ - - torch.manual_seed(0) - - assert len(self.param_groups) == 1 - - # Make sure the closure is always called with grad enabled - closure = torch.enable_grad()(closure) - - group = self.param_groups[0] - lr = group['lr'] - max_iter = group['max_iter'] - max_eval = group['max_eval'] - tolerance_grad = group['tolerance_grad'] - tolerance_change = group['tolerance_change'] - tolerance_ys = group['tolerance_ys'] - line_search_fn = group['line_search_fn'] - history_size = group['history_size'] - - # NOTE: LBFGS has only global state, but we register it as state for - # the first param, because this helps with casting in load_state_dict - state = self.state[self._params[0]] - state.setdefault('func_evals', 0) - state.setdefault('n_iter', 0) - - # evaluate initial f(x) and df/dx - orig_loss = closure() - loss = float(orig_loss) - current_evals = 1 - state['func_evals'] += 1 - - flat_grad = self._gather_flat_grad() - opt_cond = flat_grad.abs().max() <= tolerance_grad - - # optimal condition - if opt_cond: - return orig_loss - - # tensors cached in state (for tracing) - d = state.get('d') - t = state.get('t') - old_dirs = state.get('old_dirs') - old_stps = state.get('old_stps') - ro = state.get('ro') - H_diag = state.get('H_diag') - prev_flat_grad = state.get('prev_flat_grad') - prev_loss = state.get('prev_loss') - - n_iter = 0 - # optimize for a max of max_iter iterations - while n_iter < max_iter: - # keep track of nb of iterations - n_iter += 1 - state['n_iter'] += 1 - - ############################################################ - # compute gradient descent direction - ############################################################ - if state['n_iter'] == 1: - d = flat_grad.neg() - old_dirs = [] - old_stps = [] - ro = [] - H_diag = 1 - else: - # do lbfgs update (update memory) - y = flat_grad.sub(prev_flat_grad) - s = d.mul(t) - ys = y.dot(s) # y*s - if ys > tolerance_ys: - # updating memory - if len(old_dirs) == history_size: - # shift history by one (limited-memory) - old_dirs.pop(0) - old_stps.pop(0) - ro.pop(0) - - # store new direction/step - old_dirs.append(y) - old_stps.append(s) - ro.append(1. / ys) - - # update scale of initial Hessian approximation - H_diag = ys / y.dot(y) # (y*y) - - # compute the approximate (L-BFGS) inverse Hessian - # multiplied by the gradient - num_old = len(old_dirs) - - if 'al' not in state: - state['al'] = [None] * history_size - al = state['al'] - - # iteration in L-BFGS loop collapsed to use just one buffer - q = flat_grad.neg() - for i in range(num_old - 1, -1, -1): - al[i] = old_stps[i].dot(q) * ro[i] - q.add_(old_dirs[i], alpha=-al[i]) - - # multiply by initial Hessian - # r/d is the final direction - d = r = torch.mul(q, H_diag) - for i in range(num_old): - be_i = old_dirs[i].dot(r) * ro[i] - r.add_(old_stps[i], alpha=al[i] - be_i) - - if prev_flat_grad is None: - prev_flat_grad = flat_grad.clone(memory_format=torch.contiguous_format) - else: - prev_flat_grad.copy_(flat_grad) - prev_loss = loss - - ############################################################ - # compute step length - ############################################################ - # reset initial guess for step size - if state['n_iter'] == 1: - t = min(1., 1. 
/ flat_grad.abs().sum()) * lr - else: - t = lr - - # directional derivative - gtd = flat_grad.dot(d) # g * d - - # directional derivative is below tolerance - if gtd > -tolerance_change: - break - - # optional line search: user function - ls_func_evals = 0 - if line_search_fn is not None: - # perform line search, using user function - if line_search_fn != "strong_wolfe": - raise RuntimeError("only 'strong_wolfe' is supported") - else: - x_init = self._clone_param() - - def obj_func(x, t, d): - return self._directional_evaluate(closure, x, t, d) - loss, flat_grad, t, ls_func_evals = _strong_wolfe( - obj_func, x_init, t, d, loss, flat_grad, gtd) - self._add_grad(t, d) - opt_cond = flat_grad.abs().max() <= tolerance_grad - else: - # no line search, simply move with fixed-step - self._add_grad(t, d) - if n_iter != max_iter: - # re-evaluate function only if not in last iteration - # the reason we do this: in a stochastic setting, - # no use to re-evaluate that function here - with torch.enable_grad(): - loss = float(closure()) - flat_grad = self._gather_flat_grad() - opt_cond = flat_grad.abs().max() <= tolerance_grad - ls_func_evals = 1 - - # update func eval - current_evals += ls_func_evals - state['func_evals'] += ls_func_evals - - ############################################################ - # check conditions - ############################################################ - if n_iter == max_iter: - break - - if current_evals >= max_eval: - break - - # optimal condition - if opt_cond: - break - - # lack of progress - if d.mul(t).abs().max() <= tolerance_change: - break - - if abs(loss - prev_loss) < tolerance_change: - break - - state['d'] = d - state['t'] = t - state['old_dirs'] = old_dirs - state['old_stps'] = old_stps - state['ro'] = ro - state['H_diag'] = H_diag - state['prev_flat_grad'] = prev_flat_grad - state['prev_loss'] = prev_loss - - return orig_loss diff --git a/models/others/kolmogorov_arnold_networks/kan/pytorch/kan/MLP.py b/models/others/kolmogorov_arnold_networks/kan/pytorch/kan/MLP.py deleted file mode 100644 index 1066c3b3db20684b86c6fc2794c8e1c0330b3967..0000000000000000000000000000000000000000 --- a/models/others/kolmogorov_arnold_networks/kan/pytorch/kan/MLP.py +++ /dev/null @@ -1,361 +0,0 @@ -import torch -import torch.nn as nn -import matplotlib.pyplot as plt -import numpy as np -from tqdm import tqdm -from .LBFGS import LBFGS - -seed = 0 -torch.manual_seed(seed) - -class MLP(nn.Module): - - def __init__(self, width, act='silu', save_act=True, seed=0, device='cpu'): - super(MLP, self).__init__() - - torch.manual_seed(seed) - - linears = [] - self.width = width - self.depth = depth = len(width) - 1 - for i in range(depth): - linears.append(nn.Linear(width[i], width[i+1])) - self.linears = nn.ModuleList(linears) - - #if activation == 'silu': - self.act_fun = torch.nn.SiLU() - self.save_act = save_act - self.acts = None - - self.cache_data = None - - self.device = device - self.to(device) - - - def to(self, device): - super(MLP, self).to(device) - self.device = device - - return self - - - def get_act(self, x=None): - if isinstance(x, dict): - x = x['train_input'] - if x == None: - if self.cache_data != None: - x = self.cache_data - else: - raise Exception("missing input data x") - save_act = self.save_act - self.save_act = True - self.forward(x) - self.save_act = save_act - - @property - def w(self): - return [self.linears[l].weight for l in range(self.depth)] - - def forward(self, x): - - # cache data - self.cache_data = x - - self.acts = [] - self.acts_scale = [] - 
self.wa_forward = [] - self.a_forward = [] - - for i in range(self.depth): - - if self.save_act: - act = x.clone() - act_scale = torch.std(x, dim=0) - wa_forward = act_scale[None, :] * self.linears[i].weight - self.acts.append(act) - if i > 0: - self.acts_scale.append(act_scale) - self.wa_forward.append(wa_forward) - - x = self.linears[i](x) - if i < self.depth - 1: - x = self.act_fun(x) - else: - if self.save_act: - act_scale = torch.std(x, dim=0) - self.acts_scale.append(act_scale) - - return x - - def attribute(self): - if self.acts == None: - self.get_act() - - node_scores = [] - edge_scores = [] - - # back propagate from the last layer - node_score = torch.ones(self.width[-1]).requires_grad_(True).to(self.device) - node_scores.append(node_score) - - for l in range(self.depth,0,-1): - - edge_score = torch.einsum('ij,i->ij', torch.abs(self.wa_forward[l-1]), node_score/(self.acts_scale[l-1]+1e-4)) - edge_scores.append(edge_score) - - # this might be improper for MLPs (although reasonable for KANs) - node_score = torch.sum(edge_score, dim=0)/torch.sqrt(torch.tensor(self.width[l-1], device=self.device)) - #print(self.width[l]) - node_scores.append(node_score) - - self.node_scores = list(reversed(node_scores)) - self.edge_scores = list(reversed(edge_scores)) - self.wa_backward = self.edge_scores - - def plot(self, beta=3, scale=1., metric='w'): - # metric = 'w', 'act' or 'fa' - - if metric == 'fa': - self.attribute() - - depth = self.depth - y0 = 0.5 - fig, ax = plt.subplots(figsize=(3*scale,3*y0*depth*scale)) - shp = self.width - - min_spacing = 1/max(self.width) - for j in range(len(shp)): - N = shp[j] - for i in range(N): - plt.scatter(1 / (2 * N) + i / N, j * y0, s=min_spacing ** 2 * 5000 * scale ** 2, color='black') - - plt.ylim(-0.1*y0,y0*depth+0.1*y0) - plt.xlim(-0.02,1.02) - - linears = self.linears - - for ii in range(len(linears)): - linear = linears[ii] - p = linear.weight - p_shp = p.shape - - if metric == 'w': - pass - elif metric == 'act': - p = self.wa_forward[ii] - elif metric == 'fa': - p = self.wa_backward[ii] - else: - raise Exception('metric = \'{}\' not recognized. Choices are \'w\', \'act\', \'fa\'.'.format(metric)) - for i in range(p_shp[0]): - for j in range(p_shp[1]): - plt.plot([1/(2*p_shp[0])+i/p_shp[0], 1/(2*p_shp[1])+j/p_shp[1]], [y0*(ii+1),y0*ii], lw=0.5*scale, alpha=np.tanh(beta*np.abs(p[i,j].cpu().detach().numpy())), color="blue" if p[i,j]>0 else "red") - - ax.axis('off') - - def reg(self, reg_metric, lamb_l1, lamb_entropy): - - if reg_metric == 'w': - acts_scale = self.w - if reg_metric == 'act': - acts_scale = self.wa_forward - if reg_metric == 'fa': - acts_scale = self.wa_backward - if reg_metric == 'a': - acts_scale = self.acts_scale - - if len(acts_scale[0].shape) == 2: - reg_ = 0. - - for i in range(len(acts_scale)): - vec = acts_scale[i] - vec = torch.abs(vec) - - l1 = torch.sum(vec) - p_row = vec / (torch.sum(vec, dim=1, keepdim=True) + 1) - p_col = vec / (torch.sum(vec, dim=0, keepdim=True) + 1) - entropy_row = - torch.mean(torch.sum(p_row * torch.log2(p_row + 1e-4), dim=1)) - entropy_col = - torch.mean(torch.sum(p_col * torch.log2(p_col + 1e-4), dim=0)) - reg_ += lamb_l1 * l1 + lamb_entropy * (entropy_row + entropy_col) - - elif len(acts_scale[0].shape) == 1: - - reg_ = 0. 
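-            # 1D scores (per-node): apply the same l1 + entropy penalty over a
-            # single probability vector instead of over rows and columns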
-
-            for i in range(len(acts_scale)):
-                vec = acts_scale[i]
-                vec = torch.abs(vec)
-
-                l1 = torch.sum(vec)
-                p = vec / (torch.sum(vec) + 1)
-                entropy = - torch.sum(p * torch.log2(p + 1e-4))
-                reg_ += lamb_l1 * l1 + lamb_entropy * entropy
-
-        return reg_
-
-    def get_reg(self, reg_metric, lamb_l1, lamb_entropy):
-        return self.reg(reg_metric, lamb_l1, lamb_entropy)
-
-    def fit(self, dataset, opt="LBFGS", steps=100, log=1, lamb=0., lamb_l1=1., lamb_entropy=2., loss_fn=None, lr=1., batch=-1,
-            metrics=None, in_vars=None, out_vars=None, beta=3, device='cpu', reg_metric='w', display_metrics=None):
-
-        if lamb > 0. and not self.save_act:
-            print('setting lamb=0. If you want to set lamb > 0, set save_act=True')
-
-        old_save_act = self.save_act
-        if lamb == 0.:
-            self.save_act = False
-
-        pbar = tqdm(range(steps), desc='description', ncols=100)
-
-        if loss_fn == None:
-            loss_fn = loss_fn_eval = lambda x, y: torch.mean((x - y) ** 2)
-        else:
-            loss_fn = loss_fn_eval = loss_fn
-
-        if opt == "Adam":
-            optimizer = torch.optim.Adam(self.parameters(), lr=lr)
-        elif opt == "LBFGS":
-            optimizer = LBFGS(self.parameters(), lr=lr, history_size=10, line_search_fn="strong_wolfe", tolerance_grad=1e-32, tolerance_change=1e-32, tolerance_ys=1e-32)
-
-        results = {}
-        results['train_loss'] = []
-        results['test_loss'] = []
-        results['reg'] = []
-        if metrics != None:
-            for i in range(len(metrics)):
-                results[metrics[i].__name__] = []
-
-        if batch == -1 or batch > dataset['train_input'].shape[0]:
-            batch_size = dataset['train_input'].shape[0]
-            batch_size_test = dataset['test_input'].shape[0]
-        else:
-            batch_size = batch
-            batch_size_test = batch
-
-        global train_loss, reg_
-
-        def closure():
-            global train_loss, reg_
-            optimizer.zero_grad()
-            pred = self.forward(dataset['train_input'][train_id].to(self.device))
-            train_loss = loss_fn(pred, dataset['train_label'][train_id].to(self.device))
-            if self.save_act:
-                if reg_metric == 'fa':
-                    self.attribute()
-                reg_ = self.get_reg(reg_metric, lamb_l1, lamb_entropy)
-            else:
-                reg_ = torch.tensor(0.)
-            objective = train_loss + lamb * reg_
-            objective.backward()
-            return objective
-
-        for _ in pbar:
-
-            if _ == steps-1 and old_save_act:
-                self.save_act = True
-
-            train_id = np.random.choice(dataset['train_input'].shape[0], batch_size, replace=False)
-            test_id = np.random.choice(dataset['test_input'].shape[0], batch_size_test, replace=False)
-
-            if opt == "LBFGS":
-                optimizer.step(closure)
-
-            if opt == "Adam":
-                pred = self.forward(dataset['train_input'][train_id].to(self.device))
-                train_loss = loss_fn(pred, dataset['train_label'][train_id].to(self.device))
-                if self.save_act:
-                    reg_ = self.get_reg(reg_metric, lamb_l1, lamb_entropy)
-                else:
-                    reg_ = torch.tensor(0.)
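-                # total objective = data-fitting loss + lamb * regularization,
-                # matching the objective assembled in the LBFGS closure above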
- loss = train_loss + lamb * reg_ - optimizer.zero_grad() - loss.backward() - optimizer.step() - - test_loss = loss_fn_eval(self.forward(dataset['test_input'][test_id].to(self.device)), dataset['test_label'][test_id].to(self.device)) - - - if metrics != None: - for i in range(len(metrics)): - results[metrics[i].__name__].append(metrics[i]().item()) - - results['train_loss'].append(torch.sqrt(train_loss).cpu().detach().numpy()) - results['test_loss'].append(torch.sqrt(test_loss).cpu().detach().numpy()) - results['reg'].append(reg_.cpu().detach().numpy()) - - if _ % log == 0: - if display_metrics == None: - pbar.set_description("| train_loss: %.2e | test_loss: %.2e | reg: %.2e | " % (torch.sqrt(train_loss).cpu().detach().numpy(), torch.sqrt(test_loss).cpu().detach().numpy(), reg_.cpu().detach().numpy())) - else: - string = '' - data = () - for metric in display_metrics: - string += f' {metric}: %.2e |' - try: - results[metric] - except: - raise Exception(f'{metric} not recognized') - data += (results[metric][-1],) - pbar.set_description(string % data) - - return results - - @property - def connection_cost(self): - - with torch.no_grad(): - cc = 0. - for linear in self.linears: - t = torch.abs(linear.weight) - def get_coordinate(n): - return torch.linspace(0,1,steps=n+1, device=self.device)[:n] + 1/(2*n) - - in_dim = t.shape[0] - x_in = get_coordinate(in_dim) - - out_dim = t.shape[1] - x_out = get_coordinate(out_dim) - - dist = torch.abs(x_in[:,None] - x_out[None,:]) - cc += torch.sum(dist * t) - - return cc - - def swap(self, l, i1, i2): - - def swap_row(data, i1, i2): - data[i1], data[i2] = data[i2].clone(), data[i1].clone() - - def swap_col(data, i1, i2): - data[:,i1], data[:,i2] = data[:,i2].clone(), data[:,i1].clone() - - swap_row(self.linears[l-1].weight.data, i1, i2) - swap_row(self.linears[l-1].bias.data, i1, i2) - swap_col(self.linears[l].weight.data, i1, i2) - - def auto_swap_l(self, l): - - num = self.width[l] - for i in range(num): - ccs = [] - for j in range(num): - self.swap(l,i,j) - self.get_act() - self.attribute() - cc = self.connection_cost.detach().clone() - ccs.append(cc) - self.swap(l,i,j) - j = torch.argmin(torch.tensor(ccs)) - self.swap(l,i,j) - - def auto_swap(self): - depth = self.depth - for l in range(1, depth): - self.auto_swap_l(l) - - def tree(self, x=None, in_var=None, style='tree', sym_th=1e-3, sep_th=1e-1, skip_sep_test=False, verbose=False): - if x == None: - x = self.cache_data - plot_tree(self, x, in_var=in_var, style=style, sym_th=sym_th, sep_th=sep_th, skip_sep_test=skip_sep_test, verbose=verbose) \ No newline at end of file diff --git a/models/others/kolmogorov_arnold_networks/kan/pytorch/kan/MultKAN.py b/models/others/kolmogorov_arnold_networks/kan/pytorch/kan/MultKAN.py deleted file mode 100644 index 37f3e58200586b22606f3f15dd1f99f606587568..0000000000000000000000000000000000000000 --- a/models/others/kolmogorov_arnold_networks/kan/pytorch/kan/MultKAN.py +++ /dev/null @@ -1,2805 +0,0 @@ -import torch -import torch.nn as nn -import numpy as np -from .KANLayer import KANLayer -#from .Symbolic_MultKANLayer import * -from .Symbolic_KANLayer import Symbolic_KANLayer -from .LBFGS import * -import os -import glob -import matplotlib.pyplot as plt -from tqdm import tqdm -import random -import copy -#from .MultKANLayer import MultKANLayer -import pandas as pd -from sympy.printing import latex -from sympy import * -import sympy -import yaml -from .spline import curve2coef -from .utils import SYMBOLIC_LIB -from .hypothesis import plot_tree - -class 
MultKAN(nn.Module):
-    '''
-    KAN class
-
-    Attributes:
-    -----------
-        grid : int
-            the number of grid intervals
-        k : int
-            the order of the piecewise polynomial (spline order)
-        act_fun : a list of KANLayers
-        symbolic_fun : a list of Symbolic_KANLayer
-        depth : int
-            depth of KAN
-        width : list
-            number of neurons in each layer.
-            Without multiplication nodes, [2,5,5,3] means 2D inputs, 3D outputs, with 2 layers of 5 hidden neurons.
-            With multiplication nodes, [2,[5,3],[5,1],3] means besides the [2,5,5,3] KAN, there are 3 (1) mul nodes in layer 1 (2).
-        mult_arity : int, or list of int lists
-            multiplication arity for each multiplication node (the number of numbers to be multiplied)
-        base_fun : fun
-            residual function b(x); an activation function phi(x) = sb_scale * b(x) + sp_scale * spline(x)
-        symbolic_enabled : bool
-            when symbolic_enabled = False, the symbolic branch (symbolic_fun) is not computed, to save time. Default: True.
-        width_in : list
-            The number of input neurons for each layer
-        width_out : list
-            The number of output neurons for each layer
-        base_fun_name : str
-            The base function b(x)
-        grid_eps : float
-            The parameter that interpolates between uniform grid and adaptive grid (based on sample quantile)
-        node_bias : a list of 1D torch.float
-        node_scale : a list of 1D torch.float
-        subnode_bias : a list of 1D torch.float
-        subnode_scale : a list of 1D torch.float
-        affine_trainable : bool
-            indicate whether affine parameters are trainable (node_bias, node_scale, subnode_bias, subnode_scale)
-        sp_trainable : bool
-            indicate whether the overall magnitude of splines is trainable
-        sb_trainable : bool
-            indicate whether the overall magnitude of base function is trainable
-        save_act : bool
-            indicate whether intermediate activations are saved in forward pass
-        node_scores : None or list of 1D torch.float
-            node attribution score
-        edge_scores : None or list of 2D torch.float
-            edge attribution score
-        subnode_scores : None or list of 1D torch.float
-            subnode attribution score
-        cache_data : None or 2D torch.float
-            cached input data
-        acts : None or a list of 2D torch.float
-            activations on nodes
-        auto_save : bool
-            indicate whether to automatically save a checkpoint once the model is modified
-        state_id : int
-            the state of the model (used to save checkpoint)
-        ckpt_path : str
-            the folder to store checkpoints
-        round : int
-            the number of times rewind() has been called
-        device : str
-    '''
-    def __init__(self, width=None, grid=3, k=3, mult_arity = 2, noise_scale=0.3, scale_base_mu=0.0, scale_base_sigma=1.0, base_fun='silu', symbolic_enabled=True, affine_trainable=False, grid_eps=0.02, grid_range=[-1, 1], sp_trainable=True, sb_trainable=True, seed=1, save_act=True, sparse_init=False, auto_save=True, first_init=True, ckpt_path='./model', state_id=0, round=0, device='cpu'):
-        '''
-        initialize a KAN model
-
-        Args:
-        -----
-            width : list of int
-                Without multiplication nodes: :math:`[n_0, n_1, .., n_{L-1}]` specify the number of neurons in each layer (including inputs/outputs)
-                With multiplication nodes: :math:`[[n_0,m_0=0], [n_1,m_1], .., [n_{L-1},m_{L-1}]]` specify the number of addition/multiplication nodes in each layer (including inputs/outputs)
-            grid : int
-                number of grid intervals. Default: 3.
-            k : int
-                order of piecewise polynomial. Default: 3.
-            mult_arity : int, or list of int lists
-                multiplication arity for each multiplication node (the number of numbers to be multiplied)
-            noise_scale : float
-                initial injected noise to spline.
-            base_fun : str
-                the residual function b(x). Default: 'silu'
-            symbolic_enabled : bool
-                compute (True) or skip (False) symbolic computations (for efficiency). By default: True.
-            affine_trainable : bool
-                affine parameters are updated or not. Affine parameters include node_scale, node_bias, subnode_scale, subnode_bias
-            grid_eps : float
-                When grid_eps = 1, the grid is uniform; when grid_eps = 0, the grid is partitioned using percentiles of samples. 0 < grid_eps < 1 interpolates between the two extremes.
-            grid_range : list/np.array of shape (2,)
-                setting the range of grids. Default: [-1,1]. This argument is not important if fit(update_grid=True) (by default update_grid=True).
-            sp_trainable : bool
-                If true, scale_sp is trainable. Default: True.
-            sb_trainable : bool
-                If true, scale_base is trainable. Default: True.
-            device : str
-                device
-            seed : int
-                random seed
-            save_act : bool
-                indicate whether intermediate activations are saved in forward pass
-            sparse_init : bool
-                sparse initialization (True) or normal dense initialization. Default: False.
-            auto_save : bool
-                indicate whether to automatically save a checkpoint once the model is modified
-            state_id : int
-                the state of the model (used to save checkpoint)
-            ckpt_path : str
-                the folder to store checkpoints. Default: './model'
-            round : int
-                the number of times rewind() has been called
-
-        Returns:
-        --------
-            self
-
-        Example
-        -------
-        >>> from kan import *
-        >>> model = KAN(width=[2,5,1], grid=5, k=3, seed=0)
-        checkpoint directory created: ./model
-        saving model version 0.0
-        '''
-        super(MultKAN, self).__init__()
-
-        torch.manual_seed(seed)
-        np.random.seed(seed)
-        random.seed(seed)
-
-        ### initializing the numerical front ###
-
-        self.act_fun = []
-        self.depth = len(width) - 1
-
-        #print('haha1', width)
-        for i in range(len(width)):
-            #print(type(width[i]), type(width[i]) == int)
-            if type(width[i]) == int or type(width[i]) == np.int64:
-                width[i] = [width[i],0]
-
-        #print('haha2', width)
-
-        self.width = width
-
-        # if mult_arity is just a scalar, we extend it to a list of lists
-        # e.g., mult_arity = [[2,3],[4]] means that in the first hidden layer, 2 mult ops have arity 2 and 3, respectively;
-        # in the second hidden layer, 1 mult op has arity 4.
-        if isinstance(mult_arity, int):
-            self.mult_homo = True # when homo is True, parallelization is possible
-        else:
-            self.mult_homo = False # when homo is False, a for loop is required
-        self.mult_arity = mult_arity
-
-        width_in = self.width_in
-        width_out = self.width_out
-
-        self.base_fun_name = base_fun
-        if base_fun == 'silu':
-            base_fun = torch.nn.SiLU()
-        elif base_fun == 'identity':
-            base_fun = torch.nn.Identity()
-        elif base_fun == 'zero':
-            base_fun = lambda x: x*0.
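-        # worked example of the width / mult_arity bookkeeping above:
-        #   width = [2, [5, 3], 1] with mult_arity = 2 is normalized to
-        #   [[2, 0], [5, 3], [1, 0]]; the width_in / width_out properties
-        #   defined below then give width_in = [2, 8, 1] (n_add + n_mult)
-        #   and width_out = [2, 11, 1] (n_add + mult_arity * n_mult)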
- - self.grid_eps = grid_eps - self.grid_range = grid_range - - - for l in range(self.depth): - # splines - if isinstance(grid, list): - grid_l = grid[l] - else: - grid_l = grid - - if isinstance(k, list): - k_l = k[l] - else: - k_l = k - - - sp_batch = KANLayer(in_dim=width_in[l], out_dim=width_out[l+1], num=grid_l, k=k_l, noise_scale=noise_scale, scale_base_mu=scale_base_mu, scale_base_sigma=scale_base_sigma, scale_sp=1., base_fun=base_fun, grid_eps=grid_eps, grid_range=grid_range, sp_trainable=sp_trainable, sb_trainable=sb_trainable, sparse_init=sparse_init) - self.act_fun.append(sp_batch) - - self.node_bias = [] - self.node_scale = [] - self.subnode_bias = [] - self.subnode_scale = [] - - globals()['self.node_bias_0'] = torch.nn.Parameter(torch.zeros(3,1)).requires_grad_(False) - exec('self.node_bias_0' + " = torch.nn.Parameter(torch.zeros(3,1)).requires_grad_(False)") - - for l in range(self.depth): - exec(f'self.node_bias_{l} = torch.nn.Parameter(torch.zeros(width_in[l+1])).requires_grad_(affine_trainable)') - exec(f'self.node_scale_{l} = torch.nn.Parameter(torch.ones(width_in[l+1])).requires_grad_(affine_trainable)') - exec(f'self.subnode_bias_{l} = torch.nn.Parameter(torch.zeros(width_out[l+1])).requires_grad_(affine_trainable)') - exec(f'self.subnode_scale_{l} = torch.nn.Parameter(torch.ones(width_out[l+1])).requires_grad_(affine_trainable)') - exec(f'self.node_bias.append(self.node_bias_{l})') - exec(f'self.node_scale.append(self.node_scale_{l})') - exec(f'self.subnode_bias.append(self.subnode_bias_{l})') - exec(f'self.subnode_scale.append(self.subnode_scale_{l})') - - - self.act_fun = nn.ModuleList(self.act_fun) - - self.grid = grid - self.k = k - self.base_fun = base_fun - - ### initializing the symbolic front ### - self.symbolic_fun = [] - for l in range(self.depth): - sb_batch = Symbolic_KANLayer(in_dim=width_in[l], out_dim=width_out[l+1]) - self.symbolic_fun.append(sb_batch) - - self.symbolic_fun = nn.ModuleList(self.symbolic_fun) - self.symbolic_enabled = symbolic_enabled - self.affine_trainable = affine_trainable - self.sp_trainable = sp_trainable - self.sb_trainable = sb_trainable - - self.save_act = save_act - - self.node_scores = None - self.edge_scores = None - self.subnode_scores = None - - self.cache_data = None - self.acts = None - - self.auto_save = auto_save - self.state_id = 0 - self.ckpt_path = ckpt_path - self.round = round - - self.device = device - self.to(device) - - if auto_save: - if first_init: - if not os.path.exists(ckpt_path): - # Create the directory - os.makedirs(ckpt_path) - print(f"checkpoint directory created: {ckpt_path}") - print('saving model version 0.0') - - history_path = self.ckpt_path+'/history.txt' - with open(history_path, 'w') as file: - file.write(f'### Round {self.round} ###' + '\n') - file.write('init => 0.0' + '\n') - self.saveckpt(path=self.ckpt_path+'/'+'0.0') - else: - self.state_id = state_id - - self.input_id = torch.arange(self.width_in[0],) - - def to(self, device): - ''' - move the model to device - - Args: - ----- - device : str or device - - Returns: - -------- - self - - Example - ------- - >>> from kan import * - >>> device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') - >>> model = KAN(width=[2,5,1], grid=5, k=3, seed=0) - >>> model.to(device) - ''' - super(MultKAN, self).to(device) - self.device = device - - for kanlayer in self.act_fun: - kanlayer.to(device) - - for symbolic_kanlayer in self.symbolic_fun: - symbolic_kanlayer.to(device) - - return self - - @property - def width_in(self): - ''' - The 
number of input nodes for each layer - ''' - width = self.width - width_in = [width[l][0]+width[l][1] for l in range(len(width))] - return width_in - - @property - def width_out(self): - ''' - The number of output subnodes for each layer - ''' - width = self.width - if self.mult_homo == True: - width_out = [width[l][0]+self.mult_arity*width[l][1] for l in range(len(width))] - else: - width_out = [width[l][0]+int(np.sum(self.mult_arity[l])) for l in range(len(width))] - return width_out - - @property - def n_sum(self): - ''' - The number of addition nodes for each layer - ''' - width = self.width - n_sum = [width[l][0] for l in range(1,len(width)-1)] - return n_sum - - @property - def n_mult(self): - ''' - The number of multiplication nodes for each layer - ''' - width = self.width - n_mult = [width[l][1] for l in range(1,len(width)-1)] - return n_mult - - @property - def feature_score(self): - ''' - attribution scores for inputs - ''' - self.attribute() - if self.node_scores == None: - return None - else: - return self.node_scores[0] - - def initialize_from_another_model(self, another_model, x): - ''' - initialize from another model of the same width, but their 'grid' parameter can be different. - Note this is equivalent to refine() when we don't want to keep another_model - - Args: - ----- - another_model : MultKAN - x : 2D torch.float - - Returns: - -------- - self - - Example - ------- - >>> from kan import * - >>> model1 = KAN(width=[2,5,1], grid=3) - >>> model2 = KAN(width=[2,5,1], grid=10) - >>> x = torch.rand(100,2) - >>> model2.initialize_from_another_model(model1, x) - ''' - another_model(x) # get activations - batch = x.shape[0] - - self.initialize_grid_from_another_model(another_model, x) - - for l in range(self.depth): - spb = self.act_fun[l] - #spb_parent = another_model.act_fun[l] - - # spb = spb_parent - preacts = another_model.spline_preacts[l] - postsplines = another_model.spline_postsplines[l] - self.act_fun[l].coef.data = curve2coef(preacts[:,0,:], postsplines.permute(0,2,1), spb.grid, k=spb.k) - self.act_fun[l].scale_base.data = another_model.act_fun[l].scale_base.data - self.act_fun[l].scale_sp.data = another_model.act_fun[l].scale_sp.data - self.act_fun[l].mask.data = another_model.act_fun[l].mask.data - - for l in range(self.depth): - self.node_bias[l].data = another_model.node_bias[l].data - self.node_scale[l].data = another_model.node_scale[l].data - - self.subnode_bias[l].data = another_model.subnode_bias[l].data - self.subnode_scale[l].data = another_model.subnode_scale[l].data - - for l in range(self.depth): - self.symbolic_fun[l] = another_model.symbolic_fun[l] - - return self.to(self.device) - - def log_history(self, method_name): - - if self.auto_save: - - # save to log file - #print(func.__name__) - with open(self.ckpt_path+'/history.txt', 'a') as file: - file.write(str(self.round)+'.'+str(self.state_id)+' => '+ method_name + ' => ' + str(self.round)+'.'+str(self.state_id+1) + '\n') - - # update state_id - self.state_id += 1 - - # save to ckpt - self.saveckpt(path=self.ckpt_path+'/'+str(self.round)+'.'+str(self.state_id)) - print('saving model version '+str(self.round)+'.'+str(self.state_id)) - - - def refine(self, new_grid): - ''' - grid refinement - - Args: - ----- - new_grid : init - the number of grid intervals after refinement - - Returns: - -------- - a refined model : MultKAN - - Example - ------- - >>> from kan import * - >>> device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') - >>> model = KAN(width=[2,5,1], grid=5, k=3, seed=0) - 
>>> print(model.grid) - >>> x = torch.rand(100,2) - >>> model.get_act(x) - >>> model = model.refine(10) - >>> print(model.grid) - checkpoint directory created: ./model - saving model version 0.0 - 5 - saving model version 0.1 - 10 - ''' - - model_new = MultKAN(width=self.width, - grid=new_grid, - k=self.k, - mult_arity=self.mult_arity, - base_fun=self.base_fun_name, - symbolic_enabled=self.symbolic_enabled, - affine_trainable=self.affine_trainable, - grid_eps=self.grid_eps, - grid_range=self.grid_range, - sp_trainable=self.sp_trainable, - sb_trainable=self.sb_trainable, - ckpt_path=self.ckpt_path, - auto_save=True, - first_init=False, - state_id=self.state_id, - round=self.round, - device=self.device) - - model_new.initialize_from_another_model(self, self.cache_data) - model_new.cache_data = self.cache_data - model_new.grid = new_grid - - self.log_history('refine') - model_new.state_id += 1 - - return model_new.to(self.device) - - - def saveckpt(self, path='model'): - ''' - save the current model to files (configuration file and state file) - - Args: - ----- - path : str - the path where checkpoints are saved - - Returns: - -------- - None - - Example - ------- - >>> from kan import * - >>> device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') - >>> model = KAN(width=[2,5,1], grid=5, k=3, seed=0) - >>> model.saveckpt('./mark') - # There will be three files appearing in the current folder: mark_cache_data, mark_config.yml, mark_state - ''' - - model = self - - dic = dict( - width = model.width, - grid = model.grid, - k = model.k, - mult_arity = model.mult_arity, - base_fun_name = model.base_fun_name, - symbolic_enabled = model.symbolic_enabled, - affine_trainable = model.affine_trainable, - grid_eps = model.grid_eps, - grid_range = model.grid_range, - sp_trainable = model.sp_trainable, - sb_trainable = model.sb_trainable, - state_id = model.state_id, - auto_save = model.auto_save, - ckpt_path = model.ckpt_path, - round = model.round, - device = str(model.device) - ) - - for i in range (model.depth): - dic[f'symbolic.funs_name.{i}'] = model.symbolic_fun[i].funs_name - - with open(f'{path}_config.yml', 'w') as outfile: - yaml.dump(dic, outfile, default_flow_style=False) - - torch.save(model.state_dict(), f'{path}_state') - torch.save(model.cache_data, f'{path}_cache_data') - - @staticmethod - def loadckpt(path='model'): - ''' - load checkpoint from path - - Args: - ----- - path : str - the path where checkpoints are saved - - Returns: - -------- - MultKAN - - Example - ------- - >>> from kan import * - >>> model = KAN(width=[2,5,1], grid=5, k=3, seed=0) - >>> model.saveckpt('./mark') - >>> KAN.loadckpt('./mark') - ''' - with open(f'{path}_config.yml', 'r') as stream: - config = yaml.safe_load(stream) - - state = torch.load(f'{path}_state') - - model_load = MultKAN(width=config['width'], - grid=config['grid'], - k=config['k'], - mult_arity = config['mult_arity'], - base_fun=config['base_fun_name'], - symbolic_enabled=config['symbolic_enabled'], - affine_trainable=config['affine_trainable'], - grid_eps=config['grid_eps'], - grid_range=config['grid_range'], - sp_trainable=config['sp_trainable'], - sb_trainable=config['sb_trainable'], - state_id=config['state_id'], - auto_save=config['auto_save'], - first_init=False, - ckpt_path=config['ckpt_path'], - round = config['round']+1, - device = config['device']) - - model_load.load_state_dict(state) - model_load.cache_data = torch.load(f'{path}_cache_data') - - depth = len(model_load.width) - 1 - for l in range(depth): - out_dim = 
model_load.symbolic_fun[l].out_dim - in_dim = model_load.symbolic_fun[l].in_dim - funs_name = config[f'symbolic.funs_name.{l}'] - for j in range(out_dim): - for i in range(in_dim): - fun_name = funs_name[j][i] - model_load.symbolic_fun[l].funs_name[j][i] = fun_name - model_load.symbolic_fun[l].funs[j][i] = SYMBOLIC_LIB[fun_name][0] - model_load.symbolic_fun[l].funs_sympy[j][i] = SYMBOLIC_LIB[fun_name][1] - model_load.symbolic_fun[l].funs_avoid_singularity[j][i] = SYMBOLIC_LIB[fun_name][3] - return model_load - - def copy(self): - ''' - deepcopy - - Args: - ----- - path : str - the path where checkpoints are saved - - Returns: - -------- - MultKAN - - Example - ------- - >>> from kan import * - >>> model = KAN(width=[1,1], grid=5, k=3, seed=0) - >>> model2 = model.copy() - >>> model2.act_fun[0].coef.data *= 2 - >>> print(model2.act_fun[0].coef.data) - >>> print(model.act_fun[0].coef.data) - ''' - path='copy_temp' - self.saveckpt(path) - return KAN.loadckpt(path) - - def rewind(self, model_id): - ''' - rewind to an old version - - Args: - ----- - model_id : str - in format '{a}.{b}' where a is the round number, b is the version number in that round - - Returns: - -------- - MultKAN - - Example - ------- - Please refer to tutorials. API 12: Checkpoint, save & load model - ''' - self.round += 1 - self.state_id = model_id.split('.')[-1] - - history_path = self.ckpt_path+'/history.txt' - with open(history_path, 'a') as file: - file.write(f'### Round {self.round} ###' + '\n') - - self.saveckpt(path=self.ckpt_path+'/'+f'{self.round}.{self.state_id}') - - print('rewind to model version '+f'{self.round-1}.{self.state_id}'+', renamed as '+f'{self.round}.{self.state_id}') - - return MultKAN.loadckpt(path=self.ckpt_path+'/'+str(model_id)) - - - def checkout(self, model_id): - ''' - check out an old version - - Args: - ----- - model_id : str - in format '{a}.{b}' where a is the round number, b is the version number in that round - - Returns: - -------- - MultKAN - - Example - ------- - Same use as rewind, although checkout doesn't change states - ''' - return MultKAN.loadckpt(path=self.ckpt_path+'/'+str(model_id)) - - def update_grid_from_samples(self, x): - ''' - update grid from samples - - Args: - ----- - x : 2D torch.tensor - inputs - - Returns: - -------- - None - - Example - ------- - >>> from kan import * - >>> model = KAN(width=[1,1], grid=5, k=3, seed=0) - >>> print(model.act_fun[0].grid) - >>> x = torch.linspace(-10,10,steps=101)[:,None] - >>> model.update_grid_from_samples(x) - >>> print(model.act_fun[0].grid) - ''' - for l in range(self.depth): - self.get_act(x) - self.act_fun[l].update_grid_from_samples(self.acts[l]) - - def update_grid(self, x): - ''' - call update_grid_from_samples. 
This seems unnecessary but we retain it for the sake of classes that might inherit from MultKAN - ''' - self.update_grid_from_samples(x) - - def initialize_grid_from_another_model(self, model, x): - ''' - initialize grid from another model - - Args: - ----- - model : MultKAN - parent model - x : 2D torch.tensor - inputs - - Returns: - -------- - None - - Example - ------- - >>> from kan import * - >>> model = KAN(width=[1,1], grid=5, k=3, seed=0) - >>> print(model.act_fun[0].grid) - >>> x = torch.linspace(-10,10,steps=101)[:,None] - >>> model2 = KAN(width=[1,1], grid=10, k=3, seed=0) - >>> model2.initialize_grid_from_another_model(model, x) - >>> print(model2.act_fun[0].grid) - ''' - model(x) - for l in range(self.depth): - self.act_fun[l].initialize_grid_from_parent(model.act_fun[l], model.acts[l]) - - def forward(self, x, singularity_avoiding=False, y_th=10.): - ''' - forward pass - - Args: - ----- - x : 2D torch.tensor - inputs - singularity_avoiding : bool - whether to avoid singularity for the symbolic branch - y_th : float - the threshold for singularity - - Returns: - -------- - None - - Example1 - -------- - >>> from kan import * - >>> model = KAN(width=[2,5,1], grid=5, k=3, seed=0) - >>> x = torch.rand(100,2) - >>> model(x).shape - - Example2 - -------- - >>> from kan import * - >>> model = KAN(width=[1,1], grid=5, k=3, seed=0) - >>> x = torch.tensor([[1],[-0.01]]) - >>> model.fix_symbolic(0,0,0,'log',fit_params_bool=False) - >>> print(model(x)) - >>> print(model(x, singularity_avoiding=True)) - >>> print(model(x, singularity_avoiding=True, y_th=1.)) - ''' - x = x[:,self.input_id.long()] - assert x.shape[1] == self.width_in[0] - - # cache data - self.cache_data = x - - self.acts = [] # shape ([batch, n0], [batch, n1], ..., [batch, n_L]) - self.acts_premult = [] - self.spline_preacts = [] - self.spline_postsplines = [] - self.spline_postacts = [] - self.acts_scale = [] - self.acts_scale_spline = [] - self.subnode_actscale = [] - self.edge_actscale = [] - # self.neurons_scale = [] - - self.acts.append(x) # acts shape: (batch, width[l]) - - for l in range(self.depth): - - x_numerical, preacts, postacts_numerical, postspline = self.act_fun[l](x) - #print(preacts, postacts_numerical, postspline) - - if self.symbolic_enabled == True: - x_symbolic, postacts_symbolic = self.symbolic_fun[l](x, singularity_avoiding=singularity_avoiding, y_th=y_th) - else: - x_symbolic = 0. - postacts_symbolic = 0. 
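-            # each edge activation is the sum of the numerical (spline) branch
-            # and the symbolic branch; when symbolic_enabled is False the
-            # symbolic term is the scalar 0, so the addition below is a no-op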
- - x = x_numerical + x_symbolic - - if self.save_act: - # save subnode_scale - self.subnode_actscale.append(torch.std(x, dim=0).detach()) - - # subnode affine transform - x = self.subnode_scale[l][None,:] * x + self.subnode_bias[l][None,:] - - if self.save_act: - postacts = postacts_numerical + postacts_symbolic - - # self.neurons_scale.append(torch.mean(torch.abs(x), dim=0)) - #grid_reshape = self.act_fun[l].grid.reshape(self.width_out[l + 1], self.width_in[l], -1) - input_range = torch.std(preacts, dim=0) + 0.1 - output_range_spline = torch.std(postacts_numerical, dim=0) # for training, only penalize the spline part - output_range = torch.std(postacts, dim=0) # for visualization, include the contribution from both spline + symbolic - # save edge_scale - self.edge_actscale.append(output_range) - - self.acts_scale.append((output_range / input_range).detach()) - self.acts_scale_spline.append(output_range_spline / input_range) - self.spline_preacts.append(preacts.detach()) - self.spline_postacts.append(postacts.detach()) - self.spline_postsplines.append(postspline.detach()) - - self.acts_premult.append(x.detach()) - - # multiplication - dim_sum = self.width[l+1][0] - dim_mult = self.width[l+1][1] - - if self.mult_homo == True: - for i in range(self.mult_arity-1): - if i == 0: - x_mult = x[:,dim_sum::self.mult_arity] * x[:,dim_sum+1::self.mult_arity] - else: - x_mult = x_mult * x[:,dim_sum+i+1::self.mult_arity] - - else: - for j in range(dim_mult): - acml_id = dim_sum + np.sum(self.mult_arity[l+1][:j]) - for i in range(self.mult_arity[l+1][j]-1): - if i == 0: - x_mult_j = x[:,[acml_id]] * x[:,[acml_id+1]] - else: - x_mult_j = x_mult_j * x[:,[acml_id+i+1]] - - if j == 0: - x_mult = x_mult_j - else: - x_mult = torch.cat([x_mult, x_mult_j], dim=1) - - if self.width[l+1][1] > 0: - x = torch.cat([x[:,:dim_sum], x_mult], dim=1) - - # x = x + self.biases[l].weight - # node affine transform - x = self.node_scale[l][None,:] * x + self.node_bias[l][None,:] - - self.acts.append(x.detach()) - - - return x - - def set_mode(self, l, i, j, mode, mask_n=None): - if mode == "s": - mask_n = 0.; - mask_s = 1. - elif mode == "n": - mask_n = 1.; - mask_s = 0. - elif mode == "sn" or mode == "ns": - if mask_n == None: - mask_n = 1. - else: - mask_n = mask_n - mask_s = 1. - else: - mask_n = 0.; - mask_s = 0. - - self.act_fun[l].mask.data[i][j] = mask_n - self.symbolic_fun[l].mask.data[j,i] = mask_s - - def fix_symbolic(self, l, i, j, fun_name, fit_params_bool=True, a_range=(-10, 10), b_range=(-10, 10), verbose=True, random=False, log_history=True): - ''' - set (l,i,j) activation to be symbolic (specified by fun_name) - - Args: - ----- - l : int - layer index - i : int - input neuron index - j : int - output neuron index - fun_name : str - function name - fit_params_bool : bool - obtaining affine parameters through fitting (True) or setting default values (False) - a_range : tuple - sweeping range of a - b_range : tuple - sweeping range of b - verbose : bool - If True, more information is printed. 
- random : bool - initialize affine parameteres randomly or as [1,0,1,0] - log_history : bool - indicate whether to log history when the function is called - - Returns: - -------- - None or r2 (coefficient of determination) - - Example 1 - --------- - >>> # when fit_params_bool = False - >>> model = KAN(width=[2,5,1], grid=5, k=3) - >>> model.fix_symbolic(0,1,3,'sin',fit_params_bool=False) - >>> print(model.act_fun[0].mask.reshape(2,5)) - >>> print(model.symbolic_fun[0].mask.reshape(2,5)) - - Example 2 - --------- - >>> # when fit_params_bool = True - >>> model = KAN(width=[2,5,1], grid=5, k=3, noise_scale=1.) - >>> x = torch.normal(0,1,size=(100,2)) - >>> model(x) # obtain activations (otherwise model does not have attributes acts) - >>> model.fix_symbolic(0,1,3,'sin',fit_params_bool=True) - >>> print(model.act_fun[0].mask.reshape(2,5)) - >>> print(model.symbolic_fun[0].mask.reshape(2,5)) - ''' - if not fit_params_bool: - self.symbolic_fun[l].fix_symbolic(i, j, fun_name, verbose=verbose, random=random) - r2 = None - else: - x = self.acts[l][:, i] - mask = self.act_fun[l].mask - y = self.spline_postacts[l][:, j, i] - #y = self.postacts[l][:, j, i] - r2 = self.symbolic_fun[l].fix_symbolic(i, j, fun_name, x, y, a_range=a_range, b_range=b_range, verbose=verbose) - if mask[i,j] == 0: - r2 = - 1e8 - self.set_mode(l, i, j, mode="s") - - if log_history: - self.log_history('fix_symbolic') - return r2 - - def unfix_symbolic(self, l, i, j, log_history=True): - ''' - unfix the (l,i,j) activation function. - ''' - self.set_mode(l, i, j, mode="n") - self.symbolic_fun[l].funs_name[j][i] = "0" - if log_history: - self.log_history('unfix_symbolic') - - def unfix_symbolic_all(self, log_history=True): - ''' - unfix all activation functions. - ''' - for l in range(len(self.width) - 1): - for i in range(self.width_in[l]): - for j in range(self.width_out[l + 1]): - self.unfix_symbolic(l, i, j, log_history) - - def get_range(self, l, i, j, verbose=True): - ''' - Get the input range and output range of the (l,i,j) activation - - Args: - ----- - l : int - layer index - i : int - input neuron index - j : int - output neuron index - - Returns: - -------- - x_min : float - minimum of input - x_max : float - maximum of input - y_min : float - minimum of output - y_max : float - maximum of output - - Example - ------- - >>> model = KAN(width=[2,3,1], grid=5, k=3, noise_scale=1.) - >>> x = torch.normal(0,1,size=(100,2)) - >>> model(x) # do a forward pass to obtain model.acts - >>> model.get_range(0,0,0) - ''' - x = self.spline_preacts[l][:, j, i] - y = self.spline_postacts[l][:, j, i] - x_min = torch.min(x).cpu().detach().numpy() - x_max = torch.max(x).cpu().detach().numpy() - y_min = torch.min(y).cpu().detach().numpy() - y_max = torch.max(y).cpu().detach().numpy() - if verbose: - print('x range: [' + '%.2f' % x_min, ',', '%.2f' % x_max, ']') - print('y range: [' + '%.2f' % y_min, ',', '%.2f' % y_max, ']') - return x_min, x_max, y_min, y_max - - def plot(self, folder="./figures", beta=3, metric='backward', scale=0.5, tick=False, sample=False, in_vars=None, out_vars=None, title=None, varscale=1.0): - ''' - plot KAN - - Args: - ----- - folder : str - the folder to store pngs - beta : float - positive number. control the transparency of each activation. transparency = tanh(beta*l1). - mask : bool - If True, plot with mask (need to run prune() first to obtain mask). If False (by default), plot all activation functions. - mode : bool - "supervised" or "unsupervised". 
If "supervised", l1 is measured by absolution value (not subtracting mean); if "unsupervised", l1 is measured by standard deviation (subtracting mean). - scale : float - control the size of the diagram - in_vars: None or list of str - the name(s) of input variables - out_vars: None or list of str - the name(s) of output variables - title: None or str - title - varscale : float - the size of input variables - - Returns: - -------- - Figure - - Example - ------- - >>> # see more interactive examples in demos - >>> model = KAN(width=[2,3,1], grid=3, k=3, noise_scale=1.0) - >>> x = torch.normal(0,1,size=(100,2)) - >>> model(x) # do a forward pass to obtain model.acts - >>> model.plot() - ''' - global Symbol - - if not self.save_act: - print('cannot plot since data are not saved. Set save_act=True first.') - - # forward to obtain activations - if self.acts == None: - if self.cache_data == None: - raise Exception('model hasn\'t seen any data yet.') - self.forward(self.cache_data) - - if metric == 'backward': - self.attribute() - - - if not os.path.exists(folder): - os.makedirs(folder) - # matplotlib.use('Agg') - depth = len(self.width) - 1 - for l in range(depth): - w_large = 2.0 - for i in range(self.width_in[l]): - for j in range(self.width_out[l+1]): - rank = torch.argsort(self.acts[l][:, i]) - fig, ax = plt.subplots(figsize=(w_large, w_large)) - - num = rank.shape[0] - - #print(self.width_in[l]) - #print(self.width_out[l+1]) - symbolic_mask = self.symbolic_fun[l].mask[j][i] - numeric_mask = self.act_fun[l].mask[i][j] - if symbolic_mask > 0. and numeric_mask > 0.: - color = 'purple' - alpha_mask = 1 - if symbolic_mask > 0. and numeric_mask == 0.: - color = "red" - alpha_mask = 1 - if symbolic_mask == 0. and numeric_mask > 0.: - color = "black" - alpha_mask = 1 - if symbolic_mask == 0. 
and numeric_mask == 0.: - color = "white" - alpha_mask = 0 - - - if tick == True: - ax.tick_params(axis="y", direction="in", pad=-22, labelsize=50) - ax.tick_params(axis="x", direction="in", pad=-15, labelsize=50) - x_min, x_max, y_min, y_max = self.get_range(l, i, j, verbose=False) - plt.xticks([x_min, x_max], ['%2.f' % x_min, '%2.f' % x_max]) - plt.yticks([y_min, y_max], ['%2.f' % y_min, '%2.f' % y_max]) - else: - plt.xticks([]) - plt.yticks([]) - if alpha_mask == 1: - plt.gca().patch.set_edgecolor('black') - else: - plt.gca().patch.set_edgecolor('white') - plt.gca().patch.set_linewidth(1.5) - # plt.axis('off') - - plt.plot(self.acts[l][:, i][rank].cpu().detach().numpy(), self.spline_postacts[l][:, j, i][rank].cpu().detach().numpy(), color=color, lw=5) - if sample == True: - plt.scatter(self.acts[l][:, i][rank].cpu().detach().numpy(), self.spline_postacts[l][:, j, i][rank].cpu().detach().numpy(), color=color, s=400 * scale ** 2) - plt.gca().spines[:].set_color(color) - - plt.savefig(f'{folder}/sp_{l}_{i}_{j}.png', bbox_inches="tight", dpi=400) - plt.close() - - def score2alpha(score): - return np.tanh(beta * score) - - - if metric == 'forward_n': - scores = self.acts_scale - elif metric == 'forward_u': - scores = self.edge_actscale - elif metric == 'backward': - scores = self.edge_scores - else: - raise Exception(f'metric = \'{metric}\' not recognized') - - alpha = [score2alpha(score.cpu().detach().numpy()) for score in scores] - - # draw skeleton - width = np.array(self.width) - width_in = np.array(self.width_in) - width_out = np.array(self.width_out) - A = 1 - y0 = 0.3 # height: from input to pre-mult - z0 = 0.1 # height: from pre-mult to post-mult (input of next layer) - - neuron_depth = len(width) - min_spacing = A / np.maximum(np.max(width_out), 5) - - max_neuron = np.max(width_out) - max_num_weights = np.max(width_in[:-1] * width_out[1:]) - y1 = 0.4 / np.maximum(max_num_weights, 5) # size (height/width) of 1D function diagrams - y2 = 0.15 / np.maximum(max_neuron, 5) # size (height/width) of operations (sum and mult) - - fig, ax = plt.subplots(figsize=(10 * scale, 10 * scale * (neuron_depth - 1) * (y0+z0))) - # fig, ax = plt.subplots(figsize=(5,5*(neuron_depth-1)*y0)) - - # -- Transformation functions - DC_to_FC = ax.transData.transform - FC_to_NFC = fig.transFigure.inverted().transform - # -- Take data coordinates and transform them to normalized figure coordinates - DC_to_NFC = lambda x: FC_to_NFC(DC_to_FC(x)) - - # plot scatters and lines - for l in range(neuron_depth): - - n = width_in[l] - - # scatters - for i in range(n): - plt.scatter(1 / (2 * n) + i / n, l * (y0+z0), s=min_spacing ** 2 * 10000 * scale ** 2, color='black') - - # plot connections (input to pre-mult) - for i in range(n): - if l < neuron_depth - 1: - n_next = width_out[l+1] - N = n * n_next - for j in range(n_next): - id_ = i * n_next + j - - symbol_mask = self.symbolic_fun[l].mask[j][i] - numerical_mask = self.act_fun[l].mask[i][j] - if symbol_mask == 1. and numerical_mask > 0.: - color = 'purple' - alpha_mask = 1. - if symbol_mask == 1. and numerical_mask == 0.: - color = "red" - alpha_mask = 1. - if symbol_mask == 0. and numerical_mask == 1.: - color = "black" - alpha_mask = 1. - if symbol_mask == 0. and numerical_mask == 0.: - color = "white" - alpha_mask = 0. 
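-                        # edge color legend: purple = symbolic + numerical both
-                        # active, red = symbolic only, black = numerical only,
-                        # white (alpha 0) = edge fully masked out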
- - plt.plot([1 / (2 * n) + i / n, 1 / (2 * N) + id_ / N], [l * (y0+z0), l * (y0+z0) + y0/2 - y1], color=color, lw=2 * scale, alpha=alpha[l][j][i] * alpha_mask) - plt.plot([1 / (2 * N) + id_ / N, 1 / (2 * n_next) + j / n_next], [l * (y0+z0) + y0/2 + y1, l * (y0+z0)+y0], color=color, lw=2 * scale, alpha=alpha[l][j][i] * alpha_mask) - - - # plot connections (pre-mult to post-mult, post-mult = next-layer input) - if l < neuron_depth - 1: - n_in = width_out[l+1] - n_out = width_in[l+1] - mult_id = 0 - for i in range(n_in): - if i < width[l+1][0]: - j = i - else: - if i == width[l+1][0]: - if isinstance(self.mult_arity,int): - ma = self.mult_arity - else: - ma = self.mult_arity[l+1][mult_id] - current_mult_arity = ma - if current_mult_arity == 0: - mult_id += 1 - if isinstance(self.mult_arity,int): - ma = self.mult_arity - else: - ma = self.mult_arity[l+1][mult_id] - current_mult_arity = ma - j = width[l+1][0] + mult_id - current_mult_arity -= 1 - #j = (i-width[l+1][0])//self.mult_arity + width[l+1][0] - plt.plot([1 / (2 * n_in) + i / n_in, 1 / (2 * n_out) + j / n_out], [l * (y0+z0) + y0, (l+1) * (y0+z0)], color='black', lw=2 * scale) - - - - plt.xlim(0, 1) - plt.ylim(-0.1 * (y0+z0), (neuron_depth - 1 + 0.1) * (y0+z0)) - - - plt.axis('off') - - for l in range(neuron_depth - 1): - # plot splines - n = width_in[l] - for i in range(n): - n_next = width_out[l + 1] - N = n * n_next - for j in range(n_next): - id_ = i * n_next + j - im = plt.imread(f'{folder}/sp_{l}_{i}_{j}.png') - left = DC_to_NFC([1 / (2 * N) + id_ / N - y1, 0])[0] - right = DC_to_NFC([1 / (2 * N) + id_ / N + y1, 0])[0] - bottom = DC_to_NFC([0, l * (y0+z0) + y0/2 - y1])[1] - up = DC_to_NFC([0, l * (y0+z0) + y0/2 + y1])[1] - newax = fig.add_axes([left, bottom, right - left, up - bottom]) - # newax = fig.add_axes([1/(2*N)+id_/N-y1, (l+1/2)*y0-y1, y1, y1], anchor='NE') - newax.imshow(im, alpha=alpha[l][j][i]) - newax.axis('off') - - - # plot sum symbols - N = n = width_out[l+1] - for j in range(n): - id_ = j - path = os.path.dirname(os.path.abspath(__file__)) + "/assets/img/sum_symbol.png" - im = plt.imread(path) - left = DC_to_NFC([1 / (2 * N) + id_ / N - y2, 0])[0] - right = DC_to_NFC([1 / (2 * N) + id_ / N + y2, 0])[0] - bottom = DC_to_NFC([0, l * (y0+z0) + y0 - y2])[1] - up = DC_to_NFC([0, l * (y0+z0) + y0 + y2])[1] - newax = fig.add_axes([left, bottom, right - left, up - bottom]) - newax.imshow(im) - newax.axis('off') - - # plot mult symbols - N = n = width_in[l+1] - n_sum = width[l+1][0] - n_mult = width[l+1][1] - for j in range(n_mult): - id_ = j + n_sum - path = os.path.dirname(os.path.abspath(__file__)) + "/assets/img/mult_symbol.png" - im = plt.imread(path) - left = DC_to_NFC([1 / (2 * N) + id_ / N - y2, 0])[0] - right = DC_to_NFC([1 / (2 * N) + id_ / N + y2, 0])[0] - bottom = DC_to_NFC([0, (l+1) * (y0+z0) - y2])[1] - up = DC_to_NFC([0, (l+1) * (y0+z0) + y2])[1] - newax = fig.add_axes([left, bottom, right - left, up - bottom]) - newax.imshow(im) - newax.axis('off') - - if in_vars != None: - n = self.width_in[0] - for i in range(n): - if isinstance(in_vars[i], sympy.Expr): - plt.gcf().get_axes()[0].text(1 / (2 * (n)) + i / (n), -0.1, f'${latex(in_vars[i])}$', fontsize=40 * scale * varscale, horizontalalignment='center', verticalalignment='center') - else: - plt.gcf().get_axes()[0].text(1 / (2 * (n)) + i / (n), -0.1, in_vars[i], fontsize=40 * scale * varscale, horizontalalignment='center', verticalalignment='center') - - - - if out_vars != None: - n = self.width_in[-1] - for i in range(n): - if isinstance(out_vars[i], 
sympy.Expr): - plt.gcf().get_axes()[0].text(1 / (2 * (n)) + i / (n), (y0+z0) * (len(self.width) - 1) + 0.15, f'${latex(out_vars[i])}$', fontsize=40 * scale * varscale, horizontalalignment='center', verticalalignment='center') - else: - plt.gcf().get_axes()[0].text(1 / (2 * (n)) + i / (n), (y0+z0) * (len(self.width) - 1) + 0.15, out_vars[i], fontsize=40 * scale * varscale, horizontalalignment='center', verticalalignment='center') - - if title != None: - plt.gcf().get_axes()[0].text(0.5, (y0+z0) * (len(self.width) - 1) + 0.3, title, fontsize=40 * scale, horizontalalignment='center', verticalalignment='center') - - - def reg(self, reg_metric, lamb_l1, lamb_entropy, lamb_coef, lamb_coefdiff): - ''' - Get regularization - - Args: - ----- - reg_metric : the regularization metric - 'edge_forward_spline_n', 'edge_forward_spline_u', 'edge_forward_sum', 'edge_backward', 'node_backward' - lamb_l1 : float - l1 penalty strength - lamb_entropy : float - entropy penalty strength - lamb_coef : float - coefficient penalty strength - lamb_coefdiff : float - coefficient smoothness strength - - Returns: - -------- - reg_ : torch.float - - Example - ------- - >>> model = KAN(width=[2,3,1], grid=5, k=3, noise_scale=1.) - >>> x = torch.rand(100,2) - >>> model.get_act(x) - >>> model.reg('edge_forward_spline_n', 1.0, 2.0, 1.0, 1.0) - ''' - if reg_metric == 'edge_forward_spline_n': - acts_scale = self.acts_scale_spline - - elif reg_metric == 'edge_forward_sum': - acts_scale = self.acts_scale - - elif reg_metric == 'edge_forward_spline_u': - acts_scale = self.edge_actscale - - elif reg_metric == 'edge_backward': - acts_scale = self.edge_scores - - elif reg_metric == 'node_backward': - acts_scale = self.node_attribute_scores - - else: - raise Exception(f'reg_metric = {reg_metric} not recognized!') - - reg_ = 0. - for i in range(len(acts_scale)): - vec = acts_scale[i] - - l1 = torch.sum(vec) - p_row = vec / (torch.sum(vec, dim=1, keepdim=True) + 1) - p_col = vec / (torch.sum(vec, dim=0, keepdim=True) + 1) - entropy_row = - torch.mean(torch.sum(p_row * torch.log2(p_row + 1e-4), dim=1)) - entropy_col = - torch.mean(torch.sum(p_col * torch.log2(p_col + 1e-4), dim=0)) - reg_ += lamb_l1 * l1 + lamb_entropy * (entropy_row + entropy_col) # both l1 and entropy - - # regularize coefficient to encourage spline to be zero - for i in range(len(self.act_fun)): - coeff_l1 = torch.sum(torch.mean(torch.abs(self.act_fun[i].coef), dim=1)) - coeff_diff_l1 = torch.sum(torch.mean(torch.abs(torch.diff(self.act_fun[i].coef)), dim=1)) - reg_ += lamb_coef * coeff_l1 + lamb_coefdiff * coeff_diff_l1 - - return reg_ - - def get_reg(self, reg_metric, lamb_l1, lamb_entropy, lamb_coef, lamb_coefdiff): - ''' - Get regularization. This seems unnecessary but in case a class wants to inherit this, it may want to rewrite get_reg, but not reg. 
- ''' - return self.reg(reg_metric, lamb_l1, lamb_entropy, lamb_coef, lamb_coefdiff) - - def disable_symbolic_in_fit(self, lamb): - ''' - during fitting, disable save_act if lamb = 0, and disable the symbolic branch if none of the symbolic functions is active - ''' - old_save_act = self.save_act - if lamb == 0.: - self.save_act = False - - # skip symbolic if no symbolic is turned on - depth = len(self.symbolic_fun) - no_symbolic = True - for l in range(depth): - no_symbolic *= torch.sum(torch.abs(self.symbolic_fun[l].mask)) == 0 - - old_symbolic_enabled = self.symbolic_enabled - - if no_symbolic: - self.symbolic_enabled = False - - return old_save_act, old_symbolic_enabled - - def get_params(self): - ''' - Get parameters - ''' - return self.parameters() - - - def fit(self, dataset, opt="LBFGS", steps=100, log=1, lamb=0., lamb_l1=1., lamb_entropy=2., lamb_coef=0., lamb_coefdiff=0., update_grid=True, grid_update_num=10, loss_fn=None, lr=1., start_grid_update_step=-1, stop_grid_update_step=50, batch=-1, - metrics=None, save_fig=False, in_vars=None, out_vars=None, beta=3, save_fig_freq=1, img_folder='./video', singularity_avoiding=False, y_th=1000., reg_metric='edge_forward_spline_n', display_metrics=None): - ''' - training - - Args: - ----- - dataset : dict - contains dataset['train_input'], dataset['train_label'], dataset['test_input'], dataset['test_label'] - opt : str - "LBFGS" or "Adam" - steps : int - training steps - log : int - logging frequency - lamb : float - overall penalty strength - lamb_l1 : float - l1 penalty strength - lamb_entropy : float - entropy penalty strength - lamb_coef : float - coefficient magnitude penalty strength - lamb_coefdiff : float - difference of nearby coefficients (smoothness) penalty strength - update_grid : bool - If True, update grid regularly before stop_grid_update_step - grid_update_num : int - the number of grid updates before stop_grid_update_step - start_grid_update_step : int - no grid updates before this training step - stop_grid_update_step : int - no grid updates after this training step - loss_fn : function - loss function - lr : float - learning rate - batch : int - batch size, if -1 then full. - save_fig_freq : int - save figure every (save_fig_freq) steps - singularity_avoiding : bool - indicate whether to avoid singularity for the symbolic part - y_th : float - singularity threshold (anything above the threshold is considered singular and is softened in some ways) - reg_metric : str - regularization metric. Choose from {'edge_forward_spline_n', 'edge_forward_spline_u', 'edge_forward_sum', 'edge_backward', 'node_backward'} - metrics : a list of metrics (as functions) - the metrics to be computed in training - display_metrics : a list of functions - the metrics to be displayed in the tqdm progress bar - - Returns: - -------- - results : dict - results['train_loss'], 1D array of training losses (RMSE) - results['test_loss'], 1D array of test losses (RMSE) - results['reg'], 1D array of regularization - other metrics specified in metrics - - Example - ------- - >>> from kan import * - >>> model = KAN(width=[2,5,1], grid=5, k=3, noise_scale=0.3, seed=2) - >>> f = lambda x: torch.exp(torch.sin(torch.pi*x[:,[0]]) + x[:,[1]]**2) - >>> dataset = create_dataset(f, n_var=2) - >>> model.fit(dataset, opt='LBFGS', steps=20, lamb=0.001); - >>> model.plot() - # Most examples in the tutorials involve the fit() method. Please check them for usage. - ''' - - if lamb > 0. and not self.save_act: - print('setting lamb=0.
If you want to set lamb > 0, set self.save_act=True') - - old_save_act, old_symbolic_enabled = self.disable_symbolic_in_fit(lamb) - - pbar = tqdm(range(steps), desc='description', ncols=100) - - if loss_fn == None: - loss_fn = loss_fn_eval = lambda x, y: torch.mean((x - y) ** 2) - else: - loss_fn = loss_fn_eval = loss_fn - - grid_update_freq = int(stop_grid_update_step / grid_update_num) - - if opt == "Adam": - optimizer = torch.optim.Adam(self.get_params(), lr=lr) - elif opt == "LBFGS": - optimizer = LBFGS(self.get_params(), lr=lr, history_size=10, line_search_fn="strong_wolfe", tolerance_grad=1e-32, tolerance_change=1e-32, tolerance_ys=1e-32) - - results = {} - results['train_loss'] = [] - results['test_loss'] = [] - results['reg'] = [] - if metrics != None: - for i in range(len(metrics)): - results[metrics[i].__name__] = [] - - if batch == -1 or batch > dataset['train_input'].shape[0]: - batch_size = dataset['train_input'].shape[0] - batch_size_test = dataset['test_input'].shape[0] - else: - batch_size = batch - batch_size_test = batch - - global train_loss, reg_ - - def closure(): - global train_loss, reg_ - optimizer.zero_grad() - pred = self.forward(dataset['train_input'][train_id], singularity_avoiding=singularity_avoiding, y_th=y_th) - train_loss = loss_fn(pred, dataset['train_label'][train_id]) - if self.save_act: - if reg_metric == 'edge_backward': - self.attribute() - if reg_metric == 'node_backward': - self.node_attribute() - reg_ = self.get_reg(reg_metric, lamb_l1, lamb_entropy, lamb_coef, lamb_coefdiff) - else: - reg_ = torch.tensor(0.) - objective = train_loss + lamb * reg_ - objective.backward() - return objective - - if save_fig: - if not os.path.exists(img_folder): - os.makedirs(img_folder) - - for _ in pbar: - - if _ == steps-1 and old_save_act: - self.save_act = True - - if save_fig and _ % save_fig_freq == 0: - save_act = self.save_act - self.save_act = True - - train_id = np.random.choice(dataset['train_input'].shape[0], batch_size, replace=False) - test_id = np.random.choice(dataset['test_input'].shape[0], batch_size_test, replace=False) - - if _ % grid_update_freq == 0 and _ < stop_grid_update_step and update_grid and _ >= start_grid_update_step: - self.update_grid(dataset['train_input'][train_id]) - - if opt == "LBFGS": - optimizer.step(closure) - - if opt == "Adam": - pred = self.forward(dataset['train_input'][train_id], singularity_avoiding=singularity_avoiding, y_th=y_th) - train_loss = loss_fn(pred, dataset['train_label'][train_id]) - if self.save_act: - if reg_metric == 'edge_backward': - self.attribute() - if reg_metric == 'node_backward': - self.node_attribute() - reg_ = self.get_reg(reg_metric, lamb_l1, lamb_entropy, lamb_coef, lamb_coefdiff) - else: - reg_ = torch.tensor(0.) 
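Both optimizer paths in `fit()` minimize the same regularized objective, `train_loss + lamb * reg_`; only the update mechanics differ (LBFGS re-evaluates it through `closure`, Adam steps on it directly). A condensed sketch of that objective, assuming a model whose activations are already cached so `get_reg` can run:

```python
import torch

def regularized_objective(model, x, y, lamb=0.01,
                          reg_metric='edge_forward_spline_n'):
    # default loss_fn in fit() is mean squared error
    train_loss = torch.mean((model.forward(x) - y) ** 2)
    # reg() combines l1, entropy, coefficient and smoothness penalties;
    # the strengths below are fit()'s defaults
    reg_ = model.get_reg(reg_metric, lamb_l1=1., lamb_entropy=2.,
                         lamb_coef=0., lamb_coefdiff=0.)
    return train_loss + lamb * reg_
```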
- loss = train_loss + lamb * reg_ - optimizer.zero_grad() - loss.backward() - optimizer.step() - - test_loss = loss_fn_eval(self.forward(dataset['test_input'][test_id]), dataset['test_label'][test_id]) - - - if metrics != None: - for i in range(len(metrics)): - results[metrics[i].__name__].append(metrics[i]().item()) - - results['train_loss'].append(torch.sqrt(train_loss).cpu().detach().numpy()) - results['test_loss'].append(torch.sqrt(test_loss).cpu().detach().numpy()) - results['reg'].append(reg_.cpu().detach().numpy()) - - if _ % log == 0: - if display_metrics == None: - pbar.set_description("| train_loss: %.2e | test_loss: %.2e | reg: %.2e | " % (torch.sqrt(train_loss).cpu().detach().numpy(), torch.sqrt(test_loss).cpu().detach().numpy(), reg_.cpu().detach().numpy())) - else: - string = '' - data = () - for metric in display_metrics: - string += f' {metric}: %.2e |' - try: - results[metric] - except: - raise Exception(f'{metric} not recognized') - data += (results[metric][-1],) - pbar.set_description(string % data) - - - if save_fig and _ % save_fig_freq == 0: - self.plot(folder=img_folder, in_vars=in_vars, out_vars=out_vars, title="Step {}".format(_), beta=beta) - plt.savefig(img_folder + '/' + str(_) + '.jpg', bbox_inches='tight', dpi=200) - plt.close() - self.save_act = save_act - - self.log_history('fit') - # revert back to original state - self.symbolic_enabled = old_symbolic_enabled - return results - - def prune_node(self, threshold=1e-2, mode="auto", active_neurons_id=None, log_history=True): - ''' - pruning nodes - - Args: - ----- - threshold : float - if the attribution score of a neuron is below the threshold, it is considered dead and will be removed - mode : str - 'auto' or 'manual'. with 'auto', nodes are automatically pruned using threshold. with 'manual', active_neurons_id should be passed in. 
- - Returns: - -------- - pruned network : MultKAN - - Example - ------- - >>> from kan import * - >>> model = KAN(width=[2,5,1], grid=5, k=3, noise_scale=0.3, seed=2) - >>> f = lambda x: torch.exp(torch.sin(torch.pi*x[:,[0]]) + x[:,[1]]**2) - >>> dataset = create_dataset(f, n_var=2) - >>> model.fit(dataset, opt='LBFGS', steps=20, lamb=0.001); - >>> model = model.prune_node() - >>> model.plot() - ''' - if self.acts == None: - self.get_act() - - mask_up = [torch.ones(self.width_in[0], device=self.device)] - mask_down = [] - active_neurons_up = [list(range(self.width_in[0]))] - active_neurons_down = [] - num_sums = [] - num_mults = [] - mult_arities = [[]] - - if active_neurons_id != None: - mode = "manual" - - for i in range(len(self.acts_scale) - 1): - - mult_arity = [] - - if mode == "auto": - self.attribute() - overall_important_up = self.node_scores[i+1] > threshold - - elif mode == "manual": - overall_important_up = torch.zeros(self.width_in[i + 1], dtype=torch.bool, device=self.device) - overall_important_up[active_neurons_id[i]] = True - - - num_sum = torch.sum(overall_important_up[:self.width[i+1][0]]) - num_mult = torch.sum(overall_important_up[self.width[i+1][0]:]) - if self.mult_homo == True: - overall_important_down = torch.cat([overall_important_up[:self.width[i+1][0]], (overall_important_up[self.width[i+1][0]:][None,:].expand(self.mult_arity,-1)).T.reshape(-1,)], dim=0) - else: - overall_important_down = overall_important_up[:self.width[i+1][0]] - for j in range(overall_important_up[self.width[i+1][0]:].shape[0]): - active_bool = overall_important_up[self.width[i+1][0]+j] - arity = self.mult_arity[i+1][j] - overall_important_down = torch.cat([overall_important_down, torch.tensor([active_bool]*arity).to(self.device)]) - if active_bool: - mult_arity.append(arity) - - num_sums.append(num_sum.item()) - num_mults.append(num_mult.item()) - - mask_up.append(overall_important_up.float()) - mask_down.append(overall_important_down.float()) - - active_neurons_up.append(torch.where(overall_important_up == True)[0]) - active_neurons_down.append(torch.where(overall_important_down == True)[0]) - - mult_arities.append(mult_arity) - - active_neurons_down.append(list(range(self.width_out[-1]))) - mask_down.append(torch.ones(self.width_out[-1], device=self.device)) - - if self.mult_homo == False: - mult_arities.append(self.mult_arity[-1]) - - self.mask_up = mask_up - self.mask_down = mask_down - - # update act_fun[l].mask up - for l in range(len(self.acts_scale) - 1): - for i in range(self.width_in[l + 1]): - if i not in active_neurons_up[l + 1]: - self.remove_node(l + 1, i, mode='up',log_history=False) - - for i in range(self.width_out[l + 1]): - if i not in active_neurons_down[l]: - self.remove_node(l + 1, i, mode='down',log_history=False) - - model2 = MultKAN(copy.deepcopy(self.width), grid=self.grid, k=self.k, base_fun=self.base_fun_name, mult_arity=self.mult_arity, ckpt_path=self.ckpt_path, auto_save=True, first_init=False, state_id=self.state_id, round=self.round).to(self.device) - model2.load_state_dict(self.state_dict()) - - width_new = [self.width[0]] - - for i in range(len(self.acts_scale)): - - if i < len(self.acts_scale) - 1: - num_sum = num_sums[i] - num_mult = num_mults[i] - model2.node_bias[i].data = model2.node_bias[i].data[active_neurons_up[i+1]] - model2.node_scale[i].data = model2.node_scale[i].data[active_neurons_up[i+1]] - model2.subnode_bias[i].data = model2.subnode_bias[i].data[active_neurons_down[i]] - model2.subnode_scale[i].data = 
model2.subnode_scale[i].data[active_neurons_down[i]] - model2.width[i+1] = [num_sum, num_mult] - - model2.act_fun[i].out_dim_sum = num_sum - model2.act_fun[i].out_dim_mult = num_mult - - model2.symbolic_fun[i].out_dim_sum = num_sum - model2.symbolic_fun[i].out_dim_mult = num_mult - - width_new.append([num_sum, num_mult]) - - model2.act_fun[i] = model2.act_fun[i].get_subset(active_neurons_up[i], active_neurons_down[i]) - model2.symbolic_fun[i] = self.symbolic_fun[i].get_subset(active_neurons_up[i], active_neurons_down[i]) - - model2.cache_data = self.cache_data - model2.acts = None - - width_new.append(self.width[-1]) - model2.width = width_new - - if self.mult_homo == False: - model2.mult_arity = mult_arities - - if log_history: - self.log_history('prune_node') - model2.state_id += 1 - - return model2 - - def prune_edge(self, threshold=3e-2, log_history=True): - ''' - pruning edges - - Args: - ----- - threshold : float - if the attribution score of an edge is below the threshold, it is considered dead and will be set to zero. - - Returns: - -------- - pruned network : MultKAN - - Example - ------- - >>> from kan import * - >>> model = KAN(width=[2,5,1], grid=5, k=3, noise_scale=0.3, seed=2) - >>> f = lambda x: torch.exp(torch.sin(torch.pi*x[:,[0]]) + x[:,[1]]**2) - >>> dataset = create_dataset(f, n_var=2) - >>> model.fit(dataset, opt='LBFGS', steps=20, lamb=0.001); - >>> model = model.prune_edge() - >>> model.plot() - ''' - if self.acts == None: - self.get_act() - - for i in range(len(self.width)-1): - #self.act_fun[i].mask.data = ((self.acts_scale[i] > threshold).permute(1,0)).float() - old_mask = self.act_fun[i].mask.data - self.act_fun[i].mask.data = ((self.edge_scores[i] > threshold).permute(1,0)*old_mask).float() - - if log_history: - self.log_history('fix_symbolic') - - def prune(self, node_th=1e-2, edge_th=3e-2): - ''' - prune (both nodes and edges) - - Args: - ----- - node_th : float - if the attribution score of a node is below node_th, it is considered dead and will be set to zero. - edge_th : float - if the attribution score of an edge is below node_th, it is considered dead and will be set to zero. - - Returns: - -------- - pruned network : MultKAN - - Example - ------- - >>> from kan import * - >>> model = KAN(width=[2,5,1], grid=5, k=3, noise_scale=0.3, seed=2) - >>> f = lambda x: torch.exp(torch.sin(torch.pi*x[:,[0]]) + x[:,[1]]**2) - >>> dataset = create_dataset(f, n_var=2) - >>> model.fit(dataset, opt='LBFGS', steps=20, lamb=0.001); - >>> model = model.prune() - >>> model.plot() - ''' - if self.acts == None: - self.get_act() - - self = self.prune_node(node_th, log_history=False) - #self.prune_node(node_th, log_history=False) - self.forward(self.cache_data) - self.attribute() - self.prune_edge(edge_th, log_history=False) - self.log_history('prune') - return self - - def prune_input(self, threshold=1e-2, active_inputs=None, log_history=True): - ''' - prune inputs - - Args: - ----- - threshold : float - if the attribution score of the input feature is below threshold, it is considered irrelevant. - active_inputs : None or list - if a list is passed, the manual mode will disregard attribution score and prune as instructed. 
- - Returns: - -------- - pruned network : MultKAN - - Example1 - -------- - >>> # automatic - >>> from kan import * - >>> model = KAN(width=[3,5,1], grid=5, k=3, noise_scale=0.3, seed=2) - >>> f = lambda x: 1 * x[:,[0]]**2 + 0.3 * x[:,[1]]**2 + 0.0 * x[:,[2]]**2 - >>> dataset = create_dataset(f, n_var=3) - >>> model.fit(dataset, opt='LBFGS', steps=20, lamb=0.001); - >>> model.plot() - >>> model = model.prune_input() - >>> model.plot() - - Example2 - -------- - >>> # automatic - >>> from kan import * - >>> model = KAN(width=[3,5,1], grid=5, k=3, noise_scale=0.3, seed=2) - >>> f = lambda x: 1 * x[:,[0]]**2 + 0.3 * x[:,[1]]**2 + 0.0 * x[:,[2]]**2 - >>> dataset = create_dataset(f, n_var=3) - >>> model.fit(dataset, opt='LBFGS', steps=20, lamb=0.001); - >>> model.plot() - >>> model = model.prune_input(active_inputs=[0,1]) - >>> model.plot() - ''' - if active_inputs == None: - self.attribute() - input_score = self.node_scores[0] - input_mask = input_score > threshold - print('keep:', input_mask.tolist()) - input_id = torch.where(input_mask==True)[0] - - else: - input_id = torch.tensor(active_inputs, dtype=torch.long).to(self.device) - - model2 = MultKAN(copy.deepcopy(self.width), grid=self.grid, k=self.k, base_fun=self.base_fun, mult_arity=self.mult_arity, ckpt_path=self.ckpt_path, auto_save=True, first_init=False, state_id=self.state_id, round=self.round).to(self.device) - model2.load_state_dict(self.state_dict()) - - model2.act_fun[0] = model2.act_fun[0].get_subset(input_id, torch.arange(self.width_out[1])) - model2.symbolic_fun[0] = self.symbolic_fun[0].get_subset(input_id, torch.arange(self.width_out[1])) - - model2.cache_data = self.cache_data - model2.acts = None - - model2.width[0] = [len(input_id), 0] - model2.input_id = input_id - - if log_history: - self.log_history('prune_input') - model2.state_id += 1 - - return model2 - - def remove_edge(self, l, i, j, log_history=True): - ''' - remove activtion phi(l,i,j) (set its mask to zero) - ''' - self.act_fun[l].mask[i][j] = 0. - if log_history: - self.log_history('remove_edge') - - def remove_node(self, l ,i, mode='all', log_history=True): - ''' - remove neuron (l,i) (set the masks of all incoming and outgoing activation functions to zero) - ''' - if mode == 'down': - self.act_fun[l - 1].mask[:, i] = 0. - self.symbolic_fun[l - 1].mask[i, :] *= 0. - - elif mode == 'up': - self.act_fun[l].mask[i, :] = 0. - self.symbolic_fun[l].mask[:, i] *= 0. 
- - else: - self.remove_node(l, i, mode='up') - self.remove_node(l, i, mode='down') - - if log_history: - self.log_history('remove_node') - - - def attribute(self, l=None, i=None, out_score=None, plot=True): - ''' - get attribution scores - - Args: - ----- - l : None or int - layer index - i : None or int - neuron index - out_score : None or 1D torch.float - specify output scores - plot : bool - when plot = True, display the bar show - - Returns: - -------- - attribution scores - - Example - ------- - >>> from kan import * - >>> model = KAN(width=[3,5,1], grid=5, k=3, noise_scale=0.3, seed=2) - >>> f = lambda x: 1 * x[:,[0]]**2 + 0.3 * x[:,[1]]**2 + 0.0 * x[:,[2]]**2 - >>> dataset = create_dataset(f, n_var=3) - >>> model.fit(dataset, opt='LBFGS', steps=20, lamb=0.001); - >>> model.attribute() - >>> model.feature_score - ''' - # output (out_dim, in_dim) - - if l != None: - self.attribute() - out_score = self.node_scores[l] - - if self.acts == None: - self.get_act() - - def score_node2subnode(node_score, width, mult_arity, out_dim): - - assert np.sum(width) == node_score.shape[1] - if isinstance(mult_arity, int): - n_subnode = width[0] + mult_arity * width[1] - else: - n_subnode = width[0] + int(np.sum(mult_arity)) - - #subnode_score_leaf = torch.zeros(out_dim, n_subnode).requires_grad_(True) - #subnode_score = subnode_score_leaf.clone() - #subnode_score[:,:width[0]] = node_score[:,:width[0]] - subnode_score = node_score[:,:width[0]] - if isinstance(mult_arity, int): - #subnode_score[:,width[0]:] = node_score[:,width[0]:][:,:,None].expand(out_dim, node_score[width[0]:].shape[0], mult_arity).reshape(out_dim,-1) - subnode_score = torch.cat([subnode_score, node_score[:,width[0]:][:,:,None].expand(out_dim, node_score[:,width[0]:].shape[1], mult_arity).reshape(out_dim,-1)], dim=1) - else: - acml = width[0] - for i in range(len(mult_arity)): - #subnode_score[:, acml:acml+mult_arity[i]] = node_score[:, width[0]+i] - subnode_score = torch.cat([subnode_score, node_score[:, width[0]+i].expand(out_dim, mult_arity[i])], dim=1) - acml += mult_arity[i] - return subnode_score - - - node_scores = [] - subnode_scores = [] - edge_scores = [] - - l_query = l - if l == None: - l_end = self.depth - else: - l_end = l - - # back propagate from the queried layer - out_dim = self.width_in[l_end] - if out_score == None: - node_score = torch.eye(out_dim).requires_grad_(True) - else: - node_score = torch.diag(out_score).requires_grad_(True) - node_scores.append(node_score) - - device = self.act_fun[0].grid.device - - for l in range(l_end,0,-1): - - # node to subnode - if isinstance(self.mult_arity, int): - subnode_score = score_node2subnode(node_score, self.width[l], self.mult_arity, out_dim=out_dim) - else: - mult_arity = self.mult_arity[l] - #subnode_score = score_node2subnode(node_score, self.width[l], mult_arity) - subnode_score = score_node2subnode(node_score, self.width[l], mult_arity, out_dim=out_dim) - - subnode_scores.append(subnode_score) - # subnode to edge - #print(self.edge_actscale[l-1].device, subnode_score.device, self.subnode_actscale[l-1].device) - edge_score = torch.einsum('ij,ki,i->kij', self.edge_actscale[l-1], subnode_score.to(device), 1/(self.subnode_actscale[l-1]+1e-4)) - edge_scores.append(edge_score) - - # edge to node - node_score = torch.sum(edge_score, dim=1) - node_scores.append(node_score) - - self.node_scores_all = list(reversed(node_scores)) - self.edge_scores_all = list(reversed(edge_scores)) - self.subnode_scores_all = list(reversed(subnode_scores)) - - self.node_scores = 
[torch.mean(l, dim=0) for l in self.node_scores_all] - self.edge_scores = [torch.mean(l, dim=0) for l in self.edge_scores_all] - self.subnode_scores = [torch.mean(l, dim=0) for l in self.subnode_scores_all] - - # return - if l_query != None: - if i == None: - return self.node_scores_all[0] - else: - - # plot - if plot: - in_dim = self.width_in[0] - plt.figure(figsize=(1*in_dim, 3)) - plt.bar(range(in_dim),self.node_scores_all[0][i].cpu().detach().numpy()) - plt.xticks(range(in_dim)); - - return self.node_scores_all[0][i] - - def node_attribute(self): - self.node_attribute_scores = [] - for l in range(1, self.depth+1): - node_attr = self.attribute(l) - self.node_attribute_scores.append(node_attr) - - def feature_interaction(self, l, neuron_th = 1e-2, feature_th = 1e-2): - ''' - get feature interaction - - Args: - ----- - l : int - layer index - neuron_th : float - threshold to determine whether a neuron is active - feature_th : float - threshold to determine whether a feature is active - - Returns: - -------- - dictionary - - Example - ------- - >>> from kan import * - >>> model = KAN(width=[3,5,1], grid=5, k=3, noise_scale=0.3, seed=2) - >>> f = lambda x: 1 * x[:,[0]]**2 + 0.3 * x[:,[1]]**2 + 0.0 * x[:,[2]]**2 - >>> dataset = create_dataset(f, n_var=3) - >>> model.fit(dataset, opt='LBFGS', steps=20, lamb=0.001); - >>> model.attribute() - >>> model.feature_interaction(1) - ''' - dic = {} - width = self.width_in[l] - - for i in range(width): - score = self.attribute(l,i,plot=False) - - if torch.max(score) > neuron_th: - features = tuple(torch.where(score > torch.max(score) * feature_th)[0].detach().numpy()) - if features in dic.keys(): - dic[features] += 1 - else: - dic[features] = 1 - - return dic - - def suggest_symbolic(self, l, i, j, a_range=(-10, 10), b_range=(-10, 10), lib=None, topk=5, verbose=True, r2_loss_fun=lambda x: np.log2(1+1e-5-x), c_loss_fun=lambda x: x, weight_simple = 0.8): - ''' - suggest symbolic function - - Args: - ----- - l : int - layer index - i : int - neuron index in layer l - j : int - neuron index in layer j - a_range : tuple - search range of a - b_range : tuple - search range of b - lib : list of str - library of candidate symbolic functions - topk : int - the number of top functions displayed - verbose : bool - if verbose = True, print more information - r2_loss_fun : functoon - function : r2 -> "bits" - c_loss_fun : fun - function : c -> 'bits' - weight_simple : float - the simplifty weight: the higher, more prefer simplicity over performance - - - Returns: - -------- - best_name (str), best_fun (function), best_r2 (float), best_c (float) - - Example - ------- - >>> from kan import * - >>> model = KAN(width=[2,1,1], grid=5, k=3, noise_scale=0.0, seed=0) - >>> f = lambda x: torch.exp(torch.sin(torch.pi*x[:,[0]])+x[:,[1]]**2) - >>> dataset = create_dataset(f, n_var=3) - >>> model.fit(dataset, opt='LBFGS', steps=20, lamb=0.001); - >>> model.suggest_symbolic(0,1,0) - ''' - r2s = [] - cs = [] - - if lib == None: - symbolic_lib = SYMBOLIC_LIB - else: - symbolic_lib = {} - for item in lib: - symbolic_lib[item] = SYMBOLIC_LIB[item] - - # getting r2 and complexities - for (name, content) in symbolic_lib.items(): - r2 = self.fix_symbolic(l, i, j, name, a_range=a_range, b_range=b_range, verbose=False, log_history=False) - if r2 == -1e8: # zero function - r2s.append(-1e8) - else: - r2s.append(r2.item()) - self.unfix_symbolic(l, i, j, log_history=False) - c = content[2] - cs.append(c) - - r2s = np.array(r2s) - cs = np.array(cs) - r2_loss = r2_loss_fun(r2s).astype('float') 
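The ranking assembled just below scores every candidate symbolic function as a weighted sum of a complexity term and an accuracy term measured in "bits". A small restatement of the default scoring (illustrative only; lower scores win):

```python
import numpy as np

# default r2 -> "bits" map in suggest_symbolic(); r2 near 1 drives this
# strongly negative, rewarding close fits
r2_loss_fun = lambda r2: np.log2(1 + 1e-5 - r2)

weight_simple = 0.8  # default: simplicity carries most of the weight
def score(complexity, r2):
    return weight_simple * complexity + (1 - weight_simple) * r2_loss_fun(r2)
```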
- cs_loss = c_loss_fun(cs) - - loss = weight_simple * cs_loss + (1-weight_simple) * r2_loss - - sorted_ids = np.argsort(loss)[:topk] - r2s = r2s[sorted_ids][:topk] - cs = cs[sorted_ids][:topk] - r2_loss = r2_loss[sorted_ids][:topk] - cs_loss = cs_loss[sorted_ids][:topk] - loss = loss[sorted_ids][:topk] - - topk = np.minimum(topk, len(symbolic_lib)) - - if verbose == True: - # print results in a dataframe - results = {} - results['function'] = [list(symbolic_lib.items())[sorted_ids[i]][0] for i in range(topk)] - results['fitting r2'] = r2s[:topk] - results['r2 loss'] = r2_loss[:topk] - results['complexity'] = cs[:topk] - results['complexity loss'] = cs_loss[:topk] - results['total loss'] = loss[:topk] - - df = pd.DataFrame(results) - print(df) - - best_name = list(symbolic_lib.items())[sorted_ids[0]][0] - best_fun = list(symbolic_lib.items())[sorted_ids[0]][1] - best_r2 = r2s[0] - best_c = cs[0] - - return best_name, best_fun, best_r2, best_c - - def auto_symbolic(self, a_range=(-10, 10), b_range=(-10, 10), lib=None, verbose=1, weight_simple = 0.8, r2_threshold=0.0): - ''' - automatic symbolic regression for all edges - - Args: - ----- - a_range : tuple - search range of a - b_range : tuple - search range of b - lib : list of str - library of candidate symbolic functions - verbose : int - larger values => more verbose output - weight_simple : float - a weight that prioritizes simplicity (low complexity) over performance (high r2) - set to 0.0 to ignore complexity - r2_threshold : float - If r2 is below this threshold, the edge will not be fixed with any symbolic function - set to 0.0 to ignore this threshold - - Returns: - -------- - None - - Example - ------- - >>> from kan import * - >>> model = KAN(width=[2,1,1], grid=5, k=3, noise_scale=0.0, seed=0) - >>> f = lambda x: torch.exp(torch.sin(torch.pi*x[:,[0]])+x[:,[1]]**2) - >>> dataset = create_dataset(f, n_var=3) - >>> model.fit(dataset, opt='LBFGS', steps=20, lamb=0.001); - >>> model.auto_symbolic() - ''' - for l in range(len(self.width_in) - 1): - for i in range(self.width_in[l]): - for j in range(self.width_out[l + 1]): - if self.symbolic_fun[l].mask[j, i] > 0. and self.act_fun[l].mask[i][j] == 0.: - print(f'skipping ({l},{i},{j}) since already symbolic') - elif self.symbolic_fun[l].mask[j, i] == 0. and self.act_fun[l].mask[i][j] == 0.: - self.fix_symbolic(l, i, j, '0', verbose=verbose > 1, log_history=False) - print(f'fixing ({l},{i},{j}) with 0') - else: - name, fun, r2, c = self.suggest_symbolic(l, i, j, a_range=a_range, b_range=b_range, lib=lib, verbose=False, weight_simple=weight_simple) - if r2 >= r2_threshold: - self.fix_symbolic(l, i, j, name, verbose=verbose > 1, log_history=False) - if verbose >= 1: - print(f'fixing ({l},{i},{j}) with {name}, r2={r2}, c={c}') - else: - print(f'For ({l},{i},{j}) the best fit was {name}, but r^2 = {r2} and this is lower than {r2_threshold}.
This edge was omitted, keep training or try a different threshold.') - - self.log_history('auto_symbolic') - - def symbolic_formula(self, var=None, normalizer=None, output_normalizer = None): - ''' - get symbolic formula - - Args: - ----- - var : None or a list of sympy expression - input variables - normalizer : [mean, std] - output_normalizer : [mean, std] - - Returns: - -------- - None - - Example - ------- - >>> from kan import * - >>> model = KAN(width=[2,1,1], grid=5, k=3, noise_scale=0.0, seed=0) - >>> f = lambda x: torch.exp(torch.sin(torch.pi*x[:,[0]])+x[:,[1]]**2) - >>> dataset = create_dataset(f, n_var=3) - >>> model.fit(dataset, opt='LBFGS', steps=20, lamb=0.001); - >>> model.auto_symbolic() - >>> model.symbolic_formula()[0][0] - ''' - - symbolic_acts = [] - symbolic_acts_premult = [] - x = [] - - def ex_round(ex1, n_digit): - ex2 = ex1 - for a in sympy.preorder_traversal(ex1): - if isinstance(a, sympy.Float): - ex2 = ex2.subs(a, round(a, n_digit)) - return ex2 - - # define variables - if var == None: - for ii in range(1, self.width[0][0] + 1): - exec(f"x{ii} = sympy.Symbol('x_{ii}')") - exec(f"x.append(x{ii})") - elif isinstance(var[0], sympy.Expr): - x = var - else: - x = [sympy.symbols(var_) for var_ in var] - - x0 = x - - if normalizer != None: - mean = normalizer[0] - std = normalizer[1] - x = [(x[i] - mean[i]) / std[i] for i in range(len(x))] - - symbolic_acts.append(x) - - for l in range(len(self.width_in) - 1): - num_sum = self.width[l + 1][0] - num_mult = self.width[l + 1][1] - y = [] - for j in range(self.width_out[l + 1]): - yj = 0. - for i in range(self.width_in[l]): - a, b, c, d = self.symbolic_fun[l].affine[j, i] - sympy_fun = self.symbolic_fun[l].funs_sympy[j][i] - try: - yj += c * sympy_fun(a * x[i] + b) + d - except: - print('make sure all activations need to be converted to symbolic formulas first!') - return - yj = self.subnode_scale[l][j] * yj + self.subnode_bias[l][j] - if simplify == True: - y.append(sympy.simplify(yj)) - else: - y.append(yj) - - symbolic_acts_premult.append(y) - - mult = [] - for k in range(num_mult): - if isinstance(self.mult_arity, int): - mult_arity = self.mult_arity - else: - mult_arity = self.mult_arity[l+1][k] - for i in range(mult_arity-1): - if i == 0: - mult_k = y[num_sum+2*k] * y[num_sum+2*k+1] - else: - mult_k = mult_k * y[num_sum+2*k+i+1] - mult.append(mult_k) - - y = y[:num_sum] + mult - - for j in range(self.width_in[l+1]): - y[j] = self.node_scale[l][j] * y[j] + self.node_bias[l][j] - - x = y - symbolic_acts.append(x) - - if output_normalizer != None: - output_layer = symbolic_acts[-1] - means = output_normalizer[0] - stds = output_normalizer[1] - - assert len(output_layer) == len(means), 'output_normalizer does not match the output layer' - assert len(output_layer) == len(stds), 'output_normalizer does not match the output layer' - - output_layer = [(output_layer[i] * stds[i] + means[i]) for i in range(len(output_layer))] - symbolic_acts[-1] = output_layer - - - self.symbolic_acts = [[symbolic_acts[l][i] for i in range(len(symbolic_acts[l]))] for l in range(len(symbolic_acts))] - self.symbolic_acts_premult = [[symbolic_acts_premult[l][i] for i in range(len(symbolic_acts_premult[l]))] for l in range(len(symbolic_acts_premult))] - - out_dim = len(symbolic_acts[-1]) - #return [symbolic_acts[-1][i] for i in range(len(symbolic_acts[-1]))], x0 - - if simplify: - return [symbolic_acts[-1][i] for i in range(len(symbolic_acts[-1]))], x0 - else: - return [symbolic_acts[-1][i] for i in range(len(symbolic_acts[-1]))], x0 - - - def 
expand_depth(self): - ''' - expand network depth, add an identity layer at the end. For usage, please refer to tutorials interp_3_KAN_compiler.ipynb. - - Args: - ----- - None - - Returns: - -------- - None - ''' - self.depth += 1 - - # add kanlayer, set mask to zero - dim_out = self.width_in[-1] - layer = KANLayer(dim_out, dim_out, num=self.grid, k=self.k) - layer.mask *= 0. - self.act_fun.append(layer) - - self.width.append([dim_out, 0]) - self.mult_arity.append([]) - - # add symbolic_kanlayer, set mask to one. fun = identity on diagonal and zero off-diagonal - layer = Symbolic_KANLayer(dim_out, dim_out) - layer.mask += 1. - - for j in range(dim_out): - for i in range(dim_out): - if i == j: - layer.fix_symbolic(i,j,'x') - else: - layer.fix_symbolic(i,j,'0') - - self.symbolic_fun.append(layer) - - self.node_bias.append(torch.nn.Parameter(torch.zeros(dim_out,device=self.device)).requires_grad_(self.affine_trainable)) - self.node_scale.append(torch.nn.Parameter(torch.ones(dim_out,device=self.device)).requires_grad_(self.affine_trainable)) - self.subnode_bias.append(torch.nn.Parameter(torch.zeros(dim_out,device=self.device)).requires_grad_(self.affine_trainable)) - self.subnode_scale.append(torch.nn.Parameter(torch.ones(dim_out,device=self.device)).requires_grad_(self.affine_trainable)) - - def expand_width(self, layer_id, n_added_nodes, sum_bool=True, mult_arity=2): - ''' - expand network width. For usage, please refer to tutorials interp_3_KAN_compiler.ipynb. - - Args: - ----- - layer_id : int - layer index - n_added_nodes : int - the number of added nodes - sum_bool : bool - if sum_bool == True, added nodes are addition nodes; otherwise multiplication nodes - mult_arity : int - multiplication arity (the number of numbers to be multiplied) - - Returns: - -------- - None - ''' - def _expand(layer_id, n_added_nodes, sum_bool=True, mult_arity=2, added_dim='out'): - l = layer_id - in_dim = self.symbolic_fun[l].in_dim - out_dim = self.symbolic_fun[l].out_dim - if sum_bool: - - if added_dim == 'out': - new = Symbolic_KANLayer(in_dim, out_dim + n_added_nodes) - old = self.symbolic_fun[l] - in_id = np.arange(in_dim) - out_id = np.arange(out_dim + n_added_nodes) - - for j in out_id: - for i in in_id: - new.fix_symbolic(i,j,'0') - new.mask += 1. - - for j in out_id: - for i in in_id: - if j > n_added_nodes-1: - new.funs[j][i] = old.funs[j-n_added_nodes][i] - new.funs_avoid_singularity[j][i] = old.funs_avoid_singularity[j-n_added_nodes][i] - new.funs_sympy[j][i] = old.funs_sympy[j-n_added_nodes][i] - new.funs_name[j][i] = old.funs_name[j-n_added_nodes][i] - new.affine.data[j][i] = old.affine.data[j-n_added_nodes][i] - - self.symbolic_fun[l] = new - self.act_fun[l] = KANLayer(in_dim, out_dim + n_added_nodes, num=self.grid, k=self.k) - self.act_fun[l].mask *= 0.
- - self.node_scale[l].data = torch.cat([torch.ones(n_added_nodes, device=self.device), self.node_scale[l].data]) - self.node_bias[l].data = torch.cat([torch.zeros(n_added_nodes, device=self.device), self.node_bias[l].data]) - self.subnode_scale[l].data = torch.cat([torch.ones(n_added_nodes, device=self.device), self.subnode_scale[l].data]) - self.subnode_bias[l].data = torch.cat([torch.zeros(n_added_nodes, device=self.device), self.subnode_bias[l].data]) - - - - if added_dim == 'in': - new = Symbolic_KANLayer(in_dim + n_added_nodes, out_dim) - old = self.symbolic_fun[l] - in_id = np.arange(in_dim + n_added_nodes) - out_id = np.arange(out_dim) - - for j in out_id: - for i in in_id: - new.fix_symbolic(i,j,'0') - new.mask += 1. - - for j in out_id: - for i in in_id: - if i > n_added_nodes-1: - new.funs[j][i] = old.funs[j][i-n_added_nodes] - new.funs_avoid_singularity[j][i] = old.funs_avoid_singularity[j][i-n_added_nodes] - new.funs_sympy[j][i] = old.funs_sympy[j][i-n_added_nodes] - new.funs_name[j][i] = old.funs_name[j][i-n_added_nodes] - new.affine.data[j][i] = old.affine.data[j][i-n_added_nodes] - - self.symbolic_fun[l] = new - self.act_fun[l] = KANLayer(in_dim + n_added_nodes, out_dim, num=self.grid, k=self.k) - self.act_fun[l].mask *= 0. - - - else: - - if isinstance(mult_arity, int): - mult_arity = [mult_arity] * n_added_nodes - - if added_dim == 'out': - n_added_subnodes = np.sum(mult_arity) - new = Symbolic_KANLayer(in_dim, out_dim + n_added_subnodes) - old = self.symbolic_fun[l] - in_id = np.arange(in_dim) - out_id = np.arange(out_dim + n_added_nodes) - - for j in out_id: - for i in in_id: - new.fix_symbolic(i,j,'0') - new.mask += 1. - - for j in out_id: - for i in in_id: - if j < out_dim: - new.funs[j][i] = old.funs[j][i] - new.funs_avoid_singularity[j][i] = old.funs_avoid_singularity[j][i] - new.funs_sympy[j][i] = old.funs_sympy[j][i] - new.funs_name[j][i] = old.funs_name[j][i] - new.affine.data[j][i] = old.affine.data[j][i] - - self.symbolic_fun[l] = new - self.act_fun[l] = KANLayer(in_dim, out_dim + n_added_subnodes, num=self.grid, k=self.k) - self.act_fun[l].mask *= 0. - - self.node_scale[l].data = torch.cat([self.node_scale[l].data, torch.ones(n_added_nodes, device=self.device)]) - self.node_bias[l].data = torch.cat([self.node_bias[l].data, torch.zeros(n_added_nodes, device=self.device)]) - self.subnode_scale[l].data = torch.cat([self.subnode_scale[l].data, torch.ones(n_added_subnodes, device=self.device)]) - self.subnode_bias[l].data = torch.cat([self.subnode_bias[l].data, torch.zeros(n_added_subnodes, device=self.device)]) - - if added_dim == 'in': - new = Symbolic_KANLayer(in_dim + n_added_nodes, out_dim) - old = self.symbolic_fun[l] - in_id = np.arange(in_dim + n_added_nodes) - out_id = np.arange(out_dim) - - for j in out_id: - for i in in_id: - new.fix_symbolic(i,j,'0') - new.mask += 1. - - for j in out_id: - for i in in_id: - if i < in_dim: - new.funs[j][i] = old.funs[j][i] - new.funs_avoid_singularity[j][i] = old.funs_avoid_singularity[j][i] - new.funs_sympy[j][i] = old.funs_sympy[j][i] - new.funs_name[j][i] = old.funs_name[j][i] - new.affine.data[j][i] = old.affine.data[j][i] - - self.symbolic_fun[l] = new - self.act_fun[l] = KANLayer(in_dim + n_added_nodes, out_dim, num=self.grid, k=self.k) - self.act_fun[l].mask *= 0. 
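With the inner `_expand` helper finished, the lines below apply it twice: once to the output side of layer `layer_id - 1` and once to the input side of layer `layer_id`, so both adjacent layers stay consistent. A hypothetical call (the resulting width is my reading of the bookkeeping, not a documented guarantee):

```python
from kan import *

model = KAN(width=[2, 3, 1], grid=5, k=3)
model.expand_width(layer_id=1, n_added_nodes=2, sum_bool=True)
print(model.width)  # expect layer 1 to gain two addition nodes
```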
- - _expand(layer_id-1, n_added_nodes, sum_bool, mult_arity, added_dim='out') - _expand(layer_id, n_added_nodes, sum_bool, mult_arity, added_dim='in') - if sum_bool: - self.width[layer_id][0] += n_added_nodes - else: - if isinstance(mult_arity, int): - mult_arity = [mult_arity] * n_added_nodes - - self.width[layer_id][1] += n_added_nodes - self.mult_arity[layer_id] += mult_arity - - def perturb(self, mag=1.0, mode='non-intrusive'): - ''' - perturb a network. For usage, please refer to tutorials interp_3_KAN_compiler.ipynb. - - Args: - ----- - mag : float - perturbation magnitude - mode : str - perturbation mode, choices = {'non-intrusive', 'all', 'minimal'} - - Returns: - -------- - None - ''' - perturb_bool = {} - - if mode == 'all': - perturb_bool['aa_a'] = True - perturb_bool['aa_i'] = True - perturb_bool['ai'] = True - perturb_bool['ia'] = True - perturb_bool['ii'] = True - elif mode == 'non-intrusive': - perturb_bool['aa_a'] = False - perturb_bool['aa_i'] = False - perturb_bool['ai'] = True - perturb_bool['ia'] = False - perturb_bool['ii'] = True - elif mode == 'minimal': - perturb_bool['aa_a'] = True - perturb_bool['aa_i'] = False - perturb_bool['ai'] = False - perturb_bool['ia'] = False - perturb_bool['ii'] = False - else: - raise Exception('mode not recognized, valid modes are \'all\', \'non-intrusive\', \'minimal\'.') - - for l in range(self.depth): - funs_name = self.symbolic_fun[l].funs_name - for j in range(self.width_out[l+1]): - for i in range(self.width_in[l]): - out_array = list(np.array(self.symbolic_fun[l].funs_name)[j]) - in_array = list(np.array(self.symbolic_fun[l].funs_name)[:,i]) - out_active = len([i for i, x in enumerate(out_array) if x != "0"]) > 0 - in_active = len([i for i, x in enumerate(in_array) if x != "0"]) > 0 - dic = {True: 'a', False: 'i'} - edge_type = dic[in_active] + dic[out_active] - - if l < self.depth - 1 or mode != 'non-intrusive': - - if edge_type == 'aa': - if self.symbolic_fun[l].funs_name[j][i] == '0': - edge_type += '_i' - else: - edge_type += '_a' - - if perturb_bool[edge_type]: - self.act_fun[l].mask.data[i][j] = mag - - if l == self.depth - 1 and mode == 'non-intrusive': - - self.act_fun[l].mask.data[i][j] = torch.tensor(1.) - self.act_fun[l].scale_base.data[i][j] = torch.tensor(0.) - self.act_fun[l].scale_sp.data[i][j] = torch.tensor(0.) - - self.get_act(self.cache_data) - - self.log_history('perturb') - - - def module(self, start_layer, chain): - ''' - specify network modules - - Args: - ----- - start_layer : int - the earliest layer of the module - chain : str - specify neurons in the module - - Returns: - -------- - None - ''' - #chain = '[-1]->[-1,-2]->[-1]->[-1]' - groups = chain.split('->') - n_total_layers = len(groups)//2 - #start_layer = 0 - - for l in range(n_total_layers): - current_layer = cl = start_layer + l - id_in = [int(i) for i in groups[2*l][1:-1].split(',')] - id_out = [int(i) for i in groups[2*l+1][1:-1].split(',')] - - in_dim = self.width_in[cl] - out_dim = self.width_out[cl+1] - id_in_other = list(set(range(in_dim)) - set(id_in)) - id_out_other = list(set(range(out_dim)) - set(id_out)) - self.act_fun[cl].mask.data[np.ix_(id_in_other,id_out)] = 0. - self.act_fun[cl].mask.data[np.ix_(id_in,id_out_other)] = 0. - self.symbolic_fun[cl].mask.data[np.ix_(id_out,id_in_other)] = 0. - self.symbolic_fun[cl].mask.data[np.ix_(id_out_other,id_in)] = 0.
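The `chain` string parsed above alternates bracketed groups of input ids and output ids separated by `->`; everything outside the named module gets zero-masked. A hypothetical call against a two-layer model (ids are illustrative):

```python
# keep only the wiring from input neurons {0, 1} of layer 0 to output
# neuron {0} of layer 1; all other edges of that layer are masked off
model.module(start_layer=0, chain='[0,1]->[0]')
```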
- - self.log_history('module') - - def tree(self, x=None, in_var=None, style='tree', sym_th=1e-3, sep_th=1e-1, skip_sep_test=False, verbose=False): - ''' - turn KAN into a tree - ''' - if x == None: - x = self.cache_data - plot_tree(self, x, in_var=in_var, style=style, sym_th=sym_th, sep_th=sep_th, skip_sep_test=skip_sep_test, verbose=verbose) - - - def speed(self, compile=False): - ''' - turn on KAN's speed mode - ''' - self.symbolic_enabled=False - self.save_act=False - self.auto_save=False - if compile == True: - return torch.compile(self) - else: - return self - - def get_act(self, x=None): - ''' - collect intermidate activations - ''' - if isinstance(x, dict): - x = x['train_input'] - if x == None: - if self.cache_data != None: - x = self.cache_data - else: - raise Exception("missing input data x") - save_act = self.save_act - self.save_act = True - self.forward(x) - self.save_act = save_act - - def get_fun(self, l, i, j): - ''' - get function (l,i,j) - ''' - inputs = self.spline_preacts[l][:,j,i].cpu().detach().numpy() - outputs = self.spline_postacts[l][:,j,i].cpu().detach().numpy() - # they are not ordered yet - rank = np.argsort(inputs) - inputs = inputs[rank] - outputs = outputs[rank] - plt.figure(figsize=(3,3)) - plt.plot(inputs, outputs, marker="o") - return inputs, outputs - - - def history(self, k='all'): - ''' - get history - ''' - with open(self.ckpt_path+'/history.txt', 'r') as f: - data = f.readlines() - n_line = len(data) - if k == 'all': - k = n_line - - data = data[-k:] - for line in data: - print(line[:-1]) - @property - def n_edge(self): - ''' - the number of active edges - ''' - depth = len(self.act_fun) - complexity = 0 - for l in range(depth): - complexity += torch.sum(self.act_fun[l].mask > 0.) - return complexity.item() - - def evaluate(self, dataset): - evaluation = {} - evaluation['test_loss'] = torch.sqrt(torch.mean((self.forward(dataset['test_input']) - dataset['test_label'])**2)).item() - evaluation['n_edge'] = self.n_edge - evaluation['n_grid'] = self.grid - # add other metrics (maybe accuracy) - return evaluation - - def swap(self, l, i1, i2, log_history=True): - - self.act_fun[l-1].swap(i1,i2,mode='out') - self.symbolic_fun[l-1].swap(i1,i2,mode='out') - self.act_fun[l].swap(i1,i2,mode='in') - self.symbolic_fun[l].swap(i1,i2,mode='in') - - def swap_(data, i1, i2): - data[i1], data[i2] = data[i2], data[i1] - - swap_(self.node_scale[l-1].data, i1, i2) - swap_(self.node_bias[l-1].data, i1, i2) - swap_(self.subnode_scale[l-1].data, i1, i2) - swap_(self.subnode_bias[l-1].data, i1, i2) - - if log_history: - self.log_history('swap') - - @property - def connection_cost(self): - - cc = 0. 
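`connection_cost`, continued below, sums a distance-weighted attribution score over every layer: neurons sit at evenly spaced coordinates on [0, 1], and each edge pays |x_in - x_out| times its score. A standalone single-layer restatement (the function name is mine):

```python
import torch

def layer_connection_cost(t: torch.Tensor) -> torch.Tensor:
    # t[i, j] is the attribution score of the edge from input i to output j
    def coord(n):
        return torch.linspace(0, 1, steps=n + 1)[:n] + 1 / (2 * n)
    x_in, x_out = coord(t.shape[0]), coord(t.shape[1])
    return torch.sum(torch.abs(x_in[:, None] - x_out[None, :]) * t)
```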
- for t in self.edge_scores: - - def get_coordinate(n): - return torch.linspace(0,1,steps=n+1, device=self.device)[:n] + 1/(2*n) - - in_dim = t.shape[0] - x_in = get_coordinate(in_dim) - - out_dim = t.shape[1] - x_out = get_coordinate(out_dim) - - dist = torch.abs(x_in[:,None] - x_out[None,:]) - cc += torch.sum(dist * t) - - return cc - - def auto_swap_l(self, l): - - num = self.width_in[1] - for i in range(num): - ccs = [] - for j in range(num): - self.swap(l,i,j,log_history=False) - self.get_act() - self.attribute() - cc = self.connection_cost.detach().clone() - ccs.append(cc) - self.swap(l,i,j,log_history=False) - j = torch.argmin(torch.tensor(ccs)) - self.swap(l,i,j,log_history=False) - - def auto_swap(self): - ''' - automatically swap neurons such as connection costs are minimized - ''' - depth = self.depth - for l in range(1, depth): - self.auto_swap_l(l) - - self.log_history('auto_swap') - -KAN = MultKAN diff --git a/models/others/kolmogorov_arnold_networks/kan/pytorch/kan/Symbolic_KANLayer.py b/models/others/kolmogorov_arnold_networks/kan/pytorch/kan/Symbolic_KANLayer.py deleted file mode 100644 index 3b199293c0f63200cc7a98fc67e103e74d1953fe..0000000000000000000000000000000000000000 --- a/models/others/kolmogorov_arnold_networks/kan/pytorch/kan/Symbolic_KANLayer.py +++ /dev/null @@ -1,270 +0,0 @@ -import torch -import torch.nn as nn -import numpy as np -import sympy -from .utils import * - - - -class Symbolic_KANLayer(nn.Module): - ''' - KANLayer class - - Attributes: - ----------- - in_dim : int - input dimension - out_dim : int - output dimension - funs : 2D array of torch functions (or lambda functions) - symbolic functions (torch) - funs_avoid_singularity : 2D array of torch functions (or lambda functions) with singularity avoiding - funs_name : 2D arry of str - names of symbolic functions - funs_sympy : 2D array of sympy functions (or lambda functions) - symbolic functions (sympy) - affine : 3D array of floats - affine transformations of inputs and outputs - ''' - def __init__(self, in_dim=3, out_dim=2, device='cpu'): - ''' - initialize a Symbolic_KANLayer (activation functions are initialized to be identity functions) - - Args: - ----- - in_dim : int - input dimension - out_dim : int - output dimension - device : str - device - - Returns: - -------- - self - - Example - ------- - >>> sb = Symbolic_KANLayer(in_dim=3, out_dim=3) - >>> len(sb.funs), len(sb.funs[0]) - ''' - super(Symbolic_KANLayer, self).__init__() - self.out_dim = out_dim - self.in_dim = in_dim - self.mask = torch.nn.Parameter(torch.zeros(out_dim, in_dim, device=device)).requires_grad_(False) - # torch - self.funs = [[lambda x: x*0. for i in range(self.in_dim)] for j in range(self.out_dim)] - self.funs_avoid_singularity = [[lambda x, y_th: ((), x*0.) for i in range(self.in_dim)] for j in range(self.out_dim)] - # name - self.funs_name = [['0' for i in range(self.in_dim)] for j in range(self.out_dim)] - # sympy - self.funs_sympy = [[lambda x: x*0. for i in range(self.in_dim)] for j in range(self.out_dim)] - ### make funs_name the only parameter, and make others as the properties of funs_name? 
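Each edge (j, i) of this symbolic layer carries a torch function, a sympy counterpart, a name, and four affine parameters, applied as `c * f(a * x + b) + d` and gated by the mask (see `forward` below). Evaluating one edge by hand, with assumed values:

```python
import torch

a, b, c, d = 1.0, 0.0, 2.0, 0.5       # affine[j, i] = (a, b, c, d)
f = torch.sin                          # funs[j][i]
mask_ji = 1.0                          # edge active
x = torch.linspace(-1, 1, steps=5)
y = mask_ji * (c * f(a * x + b) + d)   # mirrors Symbolic_KANLayer.forward
```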
- - self.affine = torch.nn.Parameter(torch.zeros(out_dim, in_dim, 4, device=device)) - # c*f(a*x+b)+d - - self.device = device - self.to(device) - - def to(self, device): - ''' - move to device - ''' - super(Symbolic_KANLayer, self).to(device) - self.device = device - return self - - def forward(self, x, singularity_avoiding=False, y_th=10.): - ''' - forward - - Args: - ----- - x : 2D array - inputs, shape (batch, input dimension) - singularity_avoiding : bool - if True, funs_avoid_singularity is used; if False, funs is used. - y_th : float - the singularity threshold - - Returns: - -------- - y : 2D array - outputs, shape (batch, output dimension) - postacts : 3D array - activations after activation functions but before being summed on nodes - - Example - ------- - >>> sb = Symbolic_KANLayer(in_dim=3, out_dim=5) - >>> x = torch.normal(0,1,size=(100,3)) - >>> y, postacts = sb(x) - >>> y.shape, postacts.shape - (torch.Size([100, 5]), torch.Size([100, 5, 3])) - ''' - - batch = x.shape[0] - postacts = [] - - for i in range(self.in_dim): - postacts_ = [] - for j in range(self.out_dim): - if singularity_avoiding: - xij = self.affine[j,i,2]*self.funs_avoid_singularity[j][i](self.affine[j,i,0]*x[:,[i]]+self.affine[j,i,1], torch.tensor(y_th))[1]+self.affine[j,i,3] - else: - xij = self.affine[j,i,2]*self.funs[j][i](self.affine[j,i,0]*x[:,[i]]+self.affine[j,i,1])+self.affine[j,i,3] - postacts_.append(self.mask[j][i]*xij) - postacts.append(torch.stack(postacts_)) - - postacts = torch.stack(postacts) - postacts = postacts.permute(2,1,0,3)[:,:,:,0] - y = torch.sum(postacts, dim=2) - - return y, postacts - - - def get_subset(self, in_id, out_id): - ''' - get a smaller Symbolic_KANLayer from a larger Symbolic_KANLayer (used for pruning) - - Args: - ----- - in_id : list - id of selected input neurons - out_id : list - id of selected output neurons - - Returns: - -------- - spb : Symbolic_KANLayer - - Example - ------- - >>> sb_large = Symbolic_KANLayer(in_dim=10, out_dim=10) - >>> sb_small = sb_large.get_subset([0,9],[1,2,3]) - >>> sb_small.in_dim, sb_small.out_dim - ''' - sbb = Symbolic_KANLayer(self.in_dim, self.out_dim, device=self.device) - sbb.in_dim = len(in_id) - sbb.out_dim = len(out_id) - sbb.mask.data = self.mask.data[out_id][:,in_id] - sbb.funs = [[self.funs[j][i] for i in in_id] for j in out_id] - sbb.funs_avoid_singularity = [[self.funs_avoid_singularity[j][i] for i in in_id] for j in out_id] - sbb.funs_sympy = [[self.funs_sympy[j][i] for i in in_id] for j in out_id] - sbb.funs_name = [[self.funs_name[j][i] for i in in_id] for j in out_id] - sbb.affine.data = self.affine.data[out_id][:,in_id] - return sbb - - - def fix_symbolic(self, i, j, fun_name, x=None, y=None, random=False, a_range=(-10,10), b_range=(-10,10), verbose=True): - ''' - fix an activation function to be symbolic - - Args: - ----- - i : int - the id of input neuron - j : int - the id of output neuron - fun_name : str - the name of the symbolic functions - x : 1D array - preactivations - y : 1D array - postactivations - a_range : tuple - sweeping range of a - b_range : tuple - sweeping range of a - verbose : bool - print more information if True - - Returns: - -------- - r2 (coefficient of determination) - - Example 1 - --------- - >>> # when x & y are not provided. 
Affine parameters are set to a = 1, b = 0, c = 1, d = 0 - >>> sb = Symbolic_KANLayer(in_dim=3, out_dim=2) - >>> sb.fix_symbolic(2,1,'sin') - >>> print(sb.funs_name) - >>> print(sb.affine) - - Example 2 - --------- - >>> # when x & y are provided, fit_params() is called to find the best fit coefficients - >>> sb = Symbolic_KANLayer(in_dim=3, out_dim=2) - >>> batch = 100 - >>> x = torch.linspace(-1,1,steps=batch) - >>> noises = torch.normal(0,1,(batch,)) * 0.02 - >>> y = 5.0*torch.sin(3.0*x + 2.0) + 0.7 + noises - >>> sb.fix_symbolic(2,1,'sin',x,y) - >>> print(sb.funs_name) - >>> print(sb.affine[1,2,:].data) - ''' - if isinstance(fun_name,str): - fun = SYMBOLIC_LIB[fun_name][0] - fun_sympy = SYMBOLIC_LIB[fun_name][1] - fun_avoid_singularity = SYMBOLIC_LIB[fun_name][3] - self.funs_sympy[j][i] = fun_sympy - self.funs_name[j][i] = fun_name - - if x == None or y == None: - #initialzie from just fun - self.funs[j][i] = fun - self.funs_avoid_singularity[j][i] = fun_avoid_singularity - if random == False: - self.affine.data[j][i] = torch.tensor([1.,0.,1.,0.], device=self.device) - else: - self.affine.data[j][i] = torch.rand(4, device=self.device) * 2 - 1 - return None - else: - #initialize from x & y and fun - params, r2 = fit_params(x,y,fun, a_range=a_range, b_range=b_range, verbose=verbose, device=self.device) - self.funs[j][i] = fun - self.funs_avoid_singularity[j][i] = fun_avoid_singularity - self.affine.data[j][i] = params - return r2 - else: - # if fun_name itself is a function - fun = fun_name - fun_sympy = fun_name - self.funs_sympy[j][i] = fun_sympy - self.funs_name[j][i] = "anonymous" - - self.funs[j][i] = fun - self.funs_avoid_singularity[j][i] = fun - if random == False: - self.affine.data[j][i] = torch.tensor([1.,0.,1.,0.], device=self.device) - else: - self.affine.data[j][i] = torch.rand(4, device=self.device) * 2 - 1 - return None - - def swap(self, i1, i2, mode='in'): - ''' - swap the i1 neuron with the i2 neuron in input (if mode == 'in') or output (if mode == 'out') - ''' - with torch.no_grad(): - def swap_list_(data, i1, i2, mode='in'): - - if mode == 'in': - for j in range(self.out_dim): - data[j][i1], data[j][i2] = data[j][i2], data[j][i1] - - elif mode == 'out': - data[i1], data[i2] = data[i2], data[i1] - - def swap_(data, i1, i2, mode='in'): - if mode == 'in': - data[:,i1], data[:,i2] = data[:,i2].clone(), data[:,i1].clone() - - elif mode == 'out': - data[i1], data[i2] = data[i2].clone(), data[i1].clone() - - swap_list_(self.funs_name,i1,i2,mode) - swap_list_(self.funs_sympy,i1,i2,mode) - swap_list_(self.funs_avoid_singularity,i1,i2,mode) - swap_(self.affine.data,i1,i2,mode) - swap_(self.mask.data,i1,i2,mode) diff --git a/models/others/kolmogorov_arnold_networks/kan/pytorch/kan/__init__.py b/models/others/kolmogorov_arnold_networks/kan/pytorch/kan/__init__.py deleted file mode 100644 index 1ce0e47b21c80dcb7007e9742ec09b79b245dff9..0000000000000000000000000000000000000000 --- a/models/others/kolmogorov_arnold_networks/kan/pytorch/kan/__init__.py +++ /dev/null @@ -1,3 +0,0 @@ -from .MultKAN import * -from .utils import * -#torch.use_deterministic_algorithms(True) \ No newline at end of file diff --git a/models/others/kolmogorov_arnold_networks/kan/pytorch/kan/assets/img/mult_symbol.png b/models/others/kolmogorov_arnold_networks/kan/pytorch/kan/assets/img/mult_symbol.png deleted file mode 100644 index 16d9960fbd6eb0f0e9ebf3628b3a6e2b13e56c95..0000000000000000000000000000000000000000 Binary files 
a/models/others/kolmogorov_arnold_networks/kan/pytorch/kan/assets/img/mult_symbol.png and /dev/null differ diff --git a/models/others/kolmogorov_arnold_networks/kan/pytorch/kan/assets/img/sum_symbol.png b/models/others/kolmogorov_arnold_networks/kan/pytorch/kan/assets/img/sum_symbol.png deleted file mode 100644 index 724084c5e6fd2554874fdecb77254cdb3b7cca13..0000000000000000000000000000000000000000 Binary files a/models/others/kolmogorov_arnold_networks/kan/pytorch/kan/assets/img/sum_symbol.png and /dev/null differ diff --git a/models/others/kolmogorov_arnold_networks/kan/pytorch/kan/compiler.py b/models/others/kolmogorov_arnold_networks/kan/pytorch/kan/compiler.py deleted file mode 100644 index c8014829e83b1a8a67687643d5b17d02286c2d3e..0000000000000000000000000000000000000000 --- a/models/others/kolmogorov_arnold_networks/kan/pytorch/kan/compiler.py +++ /dev/null @@ -1,498 +0,0 @@ -from sympy import * -import sympy -import numpy as np -from kan.MultKAN import MultKAN -import torch - -def next_nontrivial_operation(expr, scale=1, bias=0): - ''' - remove the affine part of an expression - - Args: - ----- - expr : sympy expression - scale : float - bias : float - - Returns: - -------- - expr : sympy expression - scale : float - bias : float - - Example - ------- - >>> from kan.compiler import * - >>> from sympy import * - >>> input_vars = a, b = symbols('a b') - >>> expression = 3.14534242 * exp(sin(pi*a) + b**2) - 2.32345402 - >>> next_nontrivial_operation(expression) - ''' - if expr.func == Add or expr.func == Mul: - n_arg = len(expr.args) - n_num = 0 - n_var_id = [] - n_num_id = [] - var_args = [] - for i in range(n_arg): - is_number = expr.args[i].is_number - n_num += is_number - if not is_number: - n_var_id.append(i) - var_args.append(expr.args[i]) - else: - n_num_id.append(i) - if n_num > 0: - # trivial - if expr.func == Add: - for i in range(n_num): - if i == 0: - bias = expr.args[n_num_id[i]] - else: - bias += expr.args[n_num_id[i]] - if expr.func == Mul: - for i in range(n_num): - if i == 0: - scale = expr.args[n_num_id[i]] - else: - scale *= expr.args[n_num_id[i]] - - return next_nontrivial_operation(expr.func(*var_args), scale, bias) - else: - return expr, scale, bias - else: - return expr, scale, bias - - -def expr2kan(input_variables, expr, grid=5, k=3, auto_save=False): - ''' - compile a symbolic formula to a MultKAN - - Args: - ----- - input_variables : a list of sympy symbols - expr : sympy expression - grid : int - the number of grid intervals - k : int - spline order - auto_save : bool - if auto_save = True, models are automatically saved - - Returns: - -------- - MultKAN - - Example - ------- - >>> from kan.compiler import * - >>> from sympy import * - >>> input_vars = a, b = symbols('a b') - >>> expression = exp(sin(pi*a) + b**2) - >>> model = kanpiler(input_vars, expression) - >>> x = torch.rand(100,2) * 2 - 1 - >>> model(x) - >>> model.plot() - ''' - class Node: - def __init__(self, expr, mult_bool, depth, scale, bias, parent=None, mult_arity=None): - self.expr = expr - self.mult_bool = mult_bool - if self.mult_bool: - self.mult_arity = mult_arity - self.depth = depth - - if len(Nodes) <= depth: - Nodes.append([]) - index = 0 - else: - index = len(Nodes[depth]) - - Nodes[depth].append(self) - - self.index = index - if parent == None: - self.parent_index = None - else: - self.parent_index = parent.index - self.child_index = [] - - # update parent's child_index - if parent != None: - parent.child_index.append(self.index) - - - self.scale = scale - self.bias = bias - - - 
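For reference, the `next_nontrivial_operation` helper above can be exercised standalone; a minimal sketch (assuming the `kan.compiler` module path used in this diff), mirroring its own docstring example:

```python
from sympy import symbols, exp, sin, pi
from kan.compiler import next_nontrivial_operation

# Peel the outer affine wrapper (scale, bias) off a sympy expression.
a, b = symbols('a b')
expression = 3.14534242 * exp(sin(pi * a) + b**2) - 2.32345402
core, scale, bias = next_nontrivial_operation(expression)
# core  -> exp(sin(pi*a) + b**2)
# scale -> 3.14534242
# bias  -> -2.32345402
```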
class SubNode: - def __init__(self, expr, depth, scale, bias, parent=None): - self.expr = expr - self.depth = depth - - if len(SubNodes) <= depth: - SubNodes.append([]) - index = 0 - else: - index = len(SubNodes[depth]) - - SubNodes[depth].append(self) - - self.index = index - self.parent_index = None # shape: (2,) - self.child_index = [] # shape: (n, 2) - - # update parent's child_index - parent.child_index.append(self.index) - - self.scale = scale - self.bias = bias - - - class Connection: - def __init__(self, affine, fun, fun_name, parent=None, child=None, power_exponent=None): - # connection = activation function that connects a subnode to a node in the next layer node - self.affine = affine #[1,0,1,0] # (a,b,c,d) - self.fun = fun # y = c*fun(a*x+b)+d - self.fun_name = fun_name - self.parent_index = parent.index - self.depth = parent.depth - self.child_index = child.index - self.power_exponent = power_exponent # if fun == Pow - Connections[(self.depth,self.parent_index,self.child_index)] = self - - def create_node(expr, parent=None, n_layer=None): - #print('before', expr) - expr, scale, bias = next_nontrivial_operation(expr) - #print('after', expr) - if parent == None: - depth = 0 - else: - depth = parent.depth - - - if expr.func == Mul: - mult_arity = len(expr.args) - node = Node(expr, True, depth, scale, bias, parent=parent, mult_arity=mult_arity) - # create mult_arity SubNodes, + 1 - for i in range(mult_arity): - # create SubNode - expr_i, scale, bias = next_nontrivial_operation(expr.args[i]) - subnode = SubNode(expr_i, node.depth+1, scale, bias, parent=node) - if expr_i.func == Add: - for j in range(len(expr_i.args)): - expr_ij, scale, bias = next_nontrivial_operation(expr_i.args[j]) - # expr_ij is impossible to be Add, should be Mul or 1D - if expr_ij.func == Mul: - #print(expr_ij) - # create a node with expr_ij - new_node = create_node(expr_ij, parent=subnode, n_layer=n_layer) - # create a connection which is a linear function - c = Connection([1,0,float(scale),float(bias)], lambda x: x, 'x', parent=subnode, child=new_node) - - elif expr_ij.func == Symbol: - #print(expr_ij) - new_node = create_node(expr_ij, parent=subnode, n_layer=n_layer) - c = Connection([1,0,float(scale),float(bias)], lambda x: x, fun_name = 'x', parent=subnode, child=new_node) - - else: - # 1D function case - # create a node with expr_ij.args[0] - new_node = create_node(expr_ij.args[0], parent=subnode, n_layer=n_layer) - # create 1D function expr_ij.func - if expr_ij.func == Pow: - power_exponent = expr_ij.args[1] - else: - power_exponent = None - Connection([1,0,float(scale),float(bias)], expr_ij.func, fun_name = expr_ij.func, parent=subnode, child=new_node, power_exponent=power_exponent) - - - elif expr_i.func == Mul: - # create a node with expr_i - new_node = create_node(expr_i, parent=subnode, n_layer=n_layer) - # create 1D function, linear - Connection([1,0,1,0], lambda x: x, fun_name = 'x', parent=subnode, child=new_node) - - elif expr_i.func == Symbol: - new_node = create_node(expr_i, parent=subnode, n_layer=n_layer) - Connection([1,0,1,0], lambda x: x, fun_name = 'x', parent=subnode, child=new_node) - - else: - # 1D functions - # create a node with expr_i.args[0] - new_node = create_node(expr_i.args[0], parent=subnode, n_layer=n_layer) - # create 1D function expr_i.func - if expr_i.func == Pow: - power_exponent = expr_i.args[1] - else: - power_exponent = None - Connection([1,0,1,0], expr_i.func, fun_name = expr_i.func, parent=subnode, child=new_node, power_exponent=power_exponent) - - elif expr.func 
== Add: - - node = Node(expr, False, depth, scale, bias, parent=parent) - subnode = SubNode(expr, node.depth+1, 1, 0, parent=node) - - for i in range(len(expr.args)): - expr_i, scale, bias = next_nontrivial_operation(expr.args[i]) - if expr_i.func == Mul: - # create a node with expr_i - new_node = create_node(expr_i, parent=subnode, n_layer=n_layer) - # create a connection which is a linear function - Connection([1,0,float(scale),float(bias)], lambda x: x, fun_name = 'x', parent=subnode, child=new_node) - - elif expr_i.func == Symbol: - new_node = create_node(expr_i, parent=subnode, n_layer=n_layer) - Connection([1,0,float(scale),float(bias)], lambda x: x, fun_name = 'x', parent=subnode, child=new_node) - - else: - # 1D function case - # create a node with expr_ij.args[0] - new_node = create_node(expr_i.args[0], parent=subnode, n_layer=n_layer) - # create 1D function expr_i.func - if expr_i.func == Pow: - power_exponent = expr_i.args[1] - else: - power_exponent = None - Connection([1,0,float(scale),float(bias)], expr_i.func, fun_name = expr_i.func, parent=subnode, child=new_node, power_exponent=power_exponent) - - elif expr.func == Symbol: - # expr.func is a symbol (one of input variables) - if n_layer == None: - node = Node(expr, False, depth, scale, bias, parent=parent) - else: - node = Node(expr, False, depth, scale, bias, parent=parent) - return_node = node - for i in range(n_layer - depth): - subnode = SubNode(expr, node.depth+1, 1, 0, parent=node) - node = Node(expr, False, subnode.depth, 1, 0, parent=subnode) - Connection([1,0,1,0], lambda x: x, fun_name = 'x', parent=subnode, child=node) - node = return_node - - Start_Nodes.append(node) - - else: - # expr.func is 1D function - #print(expr, scale, bias) - node = Node(expr, False, depth, scale, bias, parent=parent) - expr_i, scale, bias = next_nontrivial_operation(expr.args[0]) - subnode = SubNode(expr_i, node.depth+1, 1, 0, parent=node) - # create a node with expr_i.args[0] - new_node = create_node(expr.args[0], parent=subnode, n_layer=n_layer) - # create 1D function expr_i.func - if expr.func == Pow: - power_exponent = expr.args[1] - else: - power_exponent = None - Connection([1,0,1,0], expr.func, fun_name = expr.func, parent=subnode, child=new_node, power_exponent=power_exponent) - - return node - - Nodes = [[]] - SubNodes = [[]] - Connections = {} - Start_Nodes = [] - - create_node(expr, n_layer=None) - - n_layer = len(Nodes) - 1 - - Nodes = [[]] - SubNodes = [[]] - Connections = {} - Start_Nodes = [] - - create_node(expr, n_layer=n_layer) - - # move affine parameters in leaf nodes to connections - for node in Start_Nodes: - c = Connections[(node.depth,node.parent_index,node.index)] - c.affine[0] = float(node.scale) - c.affine[1] = float(node.bias) - node.scale = 1. - node.bias = 0. 
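The loop above folds each leaf node's `scale` and `bias` into the `a` and `b` slots of its connection's affine, following the `y = c*fun(a*x + b) + d` convention noted earlier. A hedged illustration of that convention (`apply_affine` is a hypothetical helper for this sketch, not part of pykan):

```python
import torch

# Hypothetical helper: applies one connection's affine (a, b, c, d)
# around a 1D function, i.e. y = c * fun(a*x + b) + d.
def apply_affine(x: torch.Tensor, affine, fun) -> torch.Tensor:
    a, b, c, d = affine
    return c * fun(a * x + b) + d

x = torch.linspace(-1.0, 1.0, steps=5)
y = apply_affine(x, [1.0, 0.0, 3.0, -2.0], torch.sin)  # 3*sin(x) - 2
```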
- - #input_variables = symbol - node2var = [] - for node in Start_Nodes: - for i in range(len(input_variables)): - if node.expr == input_variables[i]: - node2var.append(i) - - # Nodes - n_mult = [] - n_sum = [] - for layer in Nodes: - n_mult.append(0) - n_sum.append(0) - for node in layer: - if node.mult_bool == True: - n_mult[-1] += 1 - else: - n_sum[-1] += 1 - - # depth - n_layer = len(Nodes) - 1 - - # converter - # input tree node id, output kan node id (distinguish sum and mult node) - # input tree subnode id, output tree subnode id - # node id - subnode_index_convert = {} - node_index_convert = {} - connection_index_convert = {} - mult_arities = [] - for layer_id in range(n_layer+1): - mult_arity = [] - i_sum = 0 - i_mult = 0 - for i in range(len(Nodes[layer_id])): - node = Nodes[layer_id][i] - if node.mult_bool == True: - kan_node_id = n_sum[layer_id] + i_mult - arity = len(node.child_index) - for i in range(arity): - subnode = SubNodes[node.depth+1][node.child_index[i]] - kan_subnode_id = n_sum[layer_id] + np.sum(mult_arity) + i - subnode_index_convert[(subnode.depth,subnode.index)] = (int(n_layer-subnode.depth),int(kan_subnode_id)) - i_mult += 1 - mult_arity.append(arity) - else: - kan_node_id = i_sum - if len(node.child_index) > 0: - subnode = SubNodes[node.depth+1][node.child_index[0]] - kan_subnode_id = i_sum - subnode_index_convert[(subnode.depth,subnode.index)] = (int(n_layer-subnode.depth),int(kan_subnode_id)) - i_sum += 1 - - if layer_id == n_layer: - # input layer - node_index_convert[(node.depth,node.index)] = (int(n_layer-node.depth),int(node2var[kan_node_id])) - else: - node_index_convert[(node.depth,node.index)] = (int(n_layer-node.depth),int(kan_node_id)) - - # node: depth (node.depth -> n_layer - node.depth) - # width (node.index -> kan_node_id) - # subnode: depth (subnode.depth -> n_layer - subnode.depth) - # width (subnote.index -> kan_subnode_id) - mult_arities.append(mult_arity) - - for index in list(Connections.keys()): - depth, subnode_id, node_id = index - # to int(n_layer-depth), - _, kan_subnode_id = subnode_index_convert[(depth, subnode_id)] - _, kan_node_id = node_index_convert[(depth, node_id)] - connection_index_convert[(depth, subnode_id, node_id)] = (n_layer-depth, kan_subnode_id, kan_node_id) - - - n_sum.reverse() - n_mult.reverse() - mult_arities.reverse() - - width = [[n_sum[i], n_mult[i]] for i in range(len(n_sum))] - width[0][0] = len(input_variables) - - # allow pass in other parameters (probably as a dictionary) in sf2kan, including grid k etc. 
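At this point `width` holds one `[n_sum, n_mult]` pair per layer, which is the format the `MultKAN` constructor below expects. A hedged sketch of building such a model directly (assuming pykan's exported `MultKAN`, as re-exported by the `kan/__init__.py` shown in this diff):

```python
import torch
from kan import MultKAN

# 2 inputs -> layer with 1 sum node and 1 mult node (arity 2) -> 1 output.
model = MultKAN(width=[[2, 0], [1, 1], [1, 0]], mult_arity=2, grid=5, k=3)
x = torch.rand(100, 2) * 2 - 1
y = model(x)
```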
- model = MultKAN(width=width, mult_arity=mult_arities, grid=grid, k=k, auto_save=auto_save) - - # clean the graph - for l in range(model.depth): - for i in range(model.width_in[l]): - for j in range(model.width_out[l+1]): - model.fix_symbolic(l,i,j,'0',fit_params_bool=False) - - # Nodes - Nodes_flat = [x for xs in Nodes for x in xs] - - self = model - - for node in Nodes_flat: - node_depth = node.depth - node_index = node.index - kan_node_depth, kan_node_index = node_index_convert[(node_depth,node_index)] - #print(kan_node_depth, kan_node_index) - if kan_node_depth > 0: - self.node_scale[kan_node_depth-1].data[kan_node_index] = float(node.scale) - self.node_bias[kan_node_depth-1].data[kan_node_index] = float(node.bias) - - - # SubNodes - SubNodes_flat = [x for xs in SubNodes for x in xs] - - for subnode in SubNodes_flat: - subnode_depth = subnode.depth - subnode_index = subnode.index - kan_subnode_depth, kan_subnode_index = subnode_index_convert[(subnode_depth,subnode_index)] - #print(kan_subnode_depth, kan_subnode_index) - self.subnode_scale[kan_subnode_depth].data[kan_subnode_index] = float(subnode.scale) - self.subnode_bias[kan_subnode_depth].data[kan_subnode_index] = float(subnode.bias) - - # Connections - Connections_flat = list(Connections.values()) - - for connection in Connections_flat: - c_depth = connection.depth - c_j = connection.parent_index - c_i = connection.child_index - kc_depth, kc_j, kc_i = connection_index_convert[(c_depth, c_j, c_i)] - - # get symbolic fun_name - fun_name = connection.fun_name - #if fun_name == Pow: - # print(connection.power_exponent) - - if fun_name == 'x': - kfun_name = 'x' - elif fun_name == exp: - kfun_name = 'exp' - elif fun_name == sin: - kfun_name = 'sin' - elif fun_name == cos: - kfun_name = 'cos' - elif fun_name == tan: - kfun_name = 'tan' - elif fun_name == sqrt: - kfun_name = 'sqrt' - elif fun_name == log: - kfun_name = 'log' - elif fun_name == tanh: - kfun_name = 'tanh' - elif fun_name == asin: - kfun_name = 'arcsin' - elif fun_name == acos: - kfun_name = 'arccos' - elif fun_name == atan: - kfun_name = 'arctan' - elif fun_name == atanh: - kfun_name = 'arctanh' - elif fun_name == sign: - kfun_name = 'sgn' - elif fun_name == Pow: - alpha = connection.power_exponent - if alpha == Rational(1,2): - kfun_name = 'x^0.5' - elif alpha == - Rational(1,2): - kfun_name = '1/x^0.5' - elif alpha == Rational(3,2): - kfun_name = 'x^1.5' - else: - alpha = int(connection.power_exponent) - if alpha > 0: - if alpha == 1: - kfun_name = 'x' - else: - kfun_name = f'x^{alpha}' - else: - if alpha == -1: - kfun_name = '1/x' - else: - kfun_name = f'1/x^{-alpha}' - - model.fix_symbolic(kc_depth, kc_i, kc_j, kfun_name, fit_params_bool=False) - model.symbolic_fun[kc_depth].affine.data.reshape(self.width_out[kc_depth+1], self.width_in[kc_depth], 4)[kc_j][kc_i] = torch.tensor(connection.affine) - - return model - - -sf2kan = kanpiler = expr2kan \ No newline at end of file diff --git a/models/others/kolmogorov_arnold_networks/kan/pytorch/kan/experiment.py b/models/others/kolmogorov_arnold_networks/kan/pytorch/kan/experiment.py deleted file mode 100644 index 9ab9e9de3197e10e2ce2c57b6db361cbf18628bb..0000000000000000000000000000000000000000 --- a/models/others/kolmogorov_arnold_networks/kan/pytorch/kan/experiment.py +++ /dev/null @@ -1,55 +0,0 @@ -import torch -from .MultKAN import * - - -def runner1(width, dataset, grids=[5,10,20], steps=20, lamb=0.001, prune_round=3, refine_round=3, edge_th=1e-2, node_th=1e-2, metrics=None, seed=1): - - result = {} - result['test_loss'] 
= [] - result['c'] = [] - result['G'] = [] - result['id'] = [] - if metrics != None: - for i in range(len(metrics)): - result[metrics[i].__name__] = [] - - def collect(evaluation): - result['test_loss'].append(evaluation['test_loss']) - result['c'].append(evaluation['n_edge']) - result['G'].append(evaluation['n_grid']) - result['id'].append(f'{model.round}.{model.state_id}') - if metrics != None: - for i in range(len(metrics)): - result[metrics[i].__name__].append(metrics[i](model, dataset).item()) - - for i in range(prune_round): - # train and prune - if i == 0: - model = KAN(width=width, grid=grids[0], seed=seed) - else: - model = model.rewind(f'{i-1}.{2*i}') - - model.fit(dataset, steps=steps, lamb=lamb) - model = model.prune(edge_th=edge_th, node_th=node_th) - evaluation = model.evaluate(dataset) - collect(evaluation) - - for j in range(refine_round): - model = model.refine(grids[j]) - model.fit(dataset, steps=steps) - evaluation = model.evaluate(dataset) - collect(evaluation) - - for key in list(result.keys()): - result[key] = np.array(result[key]) - - return result - - -def pareto_frontier(x,y): - - pf_id = np.where(np.sum((x[:,None] <= x[None,:]) * (y[:,None] <= y[None,:]), axis=0) == 1)[0] - x_pf = x[pf_id] - y_pf = y[pf_id] - - return x_pf, y_pf, pf_id \ No newline at end of file diff --git a/models/others/kolmogorov_arnold_networks/kan/pytorch/kan/feynman.py b/models/others/kolmogorov_arnold_networks/kan/pytorch/kan/feynman.py deleted file mode 100644 index 6cc55e96ff947391434abf9168b5d6792a3d7119..0000000000000000000000000000000000000000 --- a/models/others/kolmogorov_arnold_networks/kan/pytorch/kan/feynman.py +++ /dev/null @@ -1,739 +0,0 @@ -from sympy import * -import torch - - -def get_feynman_dataset(name): - - global symbols - - tpi = torch.tensor(torch.pi) - - if name == 'test': - symbol = x, y = symbols('x, y') - expr = (x+y) * sin(exp(2*y)) - f = lambda x: (x[:,[0]] + x[:,[1]])*torch.sin(torch.exp(2*x[:,[1]])) - ranges = [-1,1] - - if name == 'I.6.20a' or name == 1: - symbol = theta = symbols('theta') - symbol = [symbol] - expr = exp(-theta**2/2)/sqrt(2*pi) - f = lambda x: torch.exp(-x[:,[0]]**2/2)/torch.sqrt(2*tpi) - ranges = [[-3,3]] - - if name == 'I.6.20' or name == 2: - symbol = theta, sigma = symbols('theta sigma') - expr = exp(-theta**2/(2*sigma**2))/sqrt(2*pi*sigma**2) - f = lambda x: torch.exp(-x[:,[0]]**2/(2*x[:,[1]]**2))/torch.sqrt(2*tpi*x[:,[1]]**2) - ranges = [[-1,1],[0.5,2]] - - if name == 'I.6.20b' or name == 3: - symbol = theta, theta1, sigma = symbols('theta theta1 sigma') - expr = exp(-(theta-theta1)**2/(2*sigma**2))/sqrt(2*pi*sigma**2) - f = lambda x: torch.exp(-(x[:,[0]]-x[:,[1]])**2/(2*x[:,[2]]**2))/torch.sqrt(2*tpi*x[:,[2]]**2) - ranges = [[-1.5,1.5],[-1.5,1.5],[0.5,2]] - - if name == 'I.8.4' or name == 4: - symbol = x1, x2, y1, y2 = symbols('x1 x2 y1 y2') - expr = sqrt((x2-x1)**2+(y2-y1)**2) - f = lambda x: torch.sqrt((x[:,[1]]-x[:,[0]])**2+(x[:,[3]]-x[:,[2]])**2) - ranges = [[-1,1],[-1,1],[-1,1],[-1,1]] - - if name == 'I.9.18' or name == 5: - symbol = G, m1, m2, x1, x2, y1, y2, z1, z2 = symbols('G m1 m2 x1 x2 y1 y2 z1 z2') - expr = G*m1*m2/((x2-x1)**2+(y2-y1)**2+(z2-z1)**2) - f = lambda x: x[:,[0]]*x[:,[1]]*x[:,[2]]/((x[:,[3]]-x[:,[4]])**2+(x[:,[5]]-x[:,[6]])**2+(x[:,[7]]-x[:,[8]])**2) - ranges = [[-1,1],[-1,1],[-1,1],[-1,-0.5],[0.5,1],[-1,-0.5],[0.5,1],[-1,-0.5],[0.5,1]] - - if name == 'I.10.7' or name == 6: - symbol = m0, v, c = symbols('m0 v c') - expr = m0/sqrt(1-v**2/c**2) - f = lambda x: x[:,[0]]/torch.sqrt(1-x[:,[1]]**2/x[:,[2]]**2) - 
ranges = [[0,1],[0,1],[1,2]] - - if name == 'I.11.19' or name == 7: - symbol = x1, y1, x2, y2, x3, y3 = symbols('x1 y1 x2 y2 x3 y3') - expr = x1*y1 + x2*y2 + x3*y3 - f = lambda x: x[:,[0]]*x[:,[1]] + x[:,[2]]*x[:,[3]] + x[:,[4]]*x[:,[5]] - ranges = [-1,1] - - if name == 'I.12.1' or name == 8: - symbol = mu, Nn = symbols('mu N_n') - expr = mu * Nn - f = lambda x: x[:,[0]]*x[:,[1]] - ranges = [-1,1] - - if name == 'I.12.2' or name == 9: - symbol = q1, q2, eps, r = symbols('q1 q2 epsilon r') - expr = q1*q2/(4*pi*eps*r**2) - f = lambda x: x[:,[0]]*x[:,[1]]/(4*tpi*x[:,[2]]*x[:,[3]]**2) - ranges = [[-1,1],[-1,1],[0.5,2],[0.5,2]] - - if name == 'I.12.4' or name == 10: - symbol = q1, eps, r = symbols('q1 epsilon r') - expr = q1/(4*pi*eps*r**2) - f = lambda x: x[:,[0]]/(4*tpi*x[:,[1]]*x[:,[2]]**2) - ranges = [[-1,1],[0.5,2],[0.5,2]] - - if name == 'I.12.5' or name == 11: - symbol = q2, Ef = symbols('q2, E_f') - expr = q2*Ef - f = lambda x: x[:,[0]]*x[:,[1]] - ranges = [-1,1] - - if name == 'I.12.11' or name == 12: - symbol = q, Ef, B, v, theta = symbols('q E_f B v theta') - expr = q*(Ef + B*v*sin(theta)) - f = lambda x: x[:,[0]]*(x[:,[1]]+x[:,[2]]*x[:,[3]]*torch.sin(x[:,[4]])) - ranges = [[-1,1],[-1,1],[-1,1],[-1,1],[0,2*tpi]] - - if name == 'I.13.4' or name == 13: - symbol = m, v, u, w = symbols('m u v w') - expr = 1/2*m*(v**2+u**2+w**2) - f = lambda x: 1/2*x[:,[0]]*(x[:,[1]]**2+x[:,[2]]**2+x[:,[3]]**2) - ranges = [[-1,1],[-1,1],[-1,1],[-1,1]] - - if name == 'I.13.12' or name == 14: - symbol = G, m1, m2, r1, r2 = symbols('G m1 m2 r1 r2') - expr = G*m1*m2*(1/r2-1/r1) - f = lambda x: x[:,[0]]*x[:,[1]]*x[:,[2]]*(1/x[:,[4]]-1/x[:,[3]]) - ranges = [[0,1],[0,1],[0,1],[0.5,2],[0.5,2]] - - if name == 'I.14.3' or name == 15: - symbol = m, g, z = symbols('m g z') - expr = m*g*z - f = lambda x: x[:,[0]]*x[:,[1]]*x[:,[2]] - ranges = [[0,1],[0,1],[-1,1]] - - if name == 'I.14.4' or name == 16: - symbol = ks, x = symbols('k_s x') - expr = 1/2*ks*x**2 - f = lambda x: 1/2*x[:,[0]]*x[:,[1]]**2 - ranges = [[0,1],[-1,1]] - - if name == 'I.15.3x' or name == 17: - symbol = x, u, t, c = symbols('x u t c') - expr = (x-u*t)/sqrt(1-u**2/c**2) - f = lambda x: (x[:,[0]] - x[:,[1]]*x[:,[2]])/torch.sqrt(1-x[:,[1]]**2/x[:,[3]]**2) - ranges = [[-1,1],[-1,1],[-1,1],[1,2]] - - if name == 'I.15.3t' or name == 18: - symbol = t, u, x, c = symbols('t u x c') - expr = (t-u*x/c**2)/sqrt(1-u**2/c**2) - f = lambda x: (x[:,[0]] - x[:,[1]]*x[:,[2]]/x[:,[3]]**2)/torch.sqrt(1-x[:,[1]]**2/x[:,[3]]**2) - ranges = [[-1,1],[-1,1],[-1,1],[1,2]] - - if name == 'I.15.10' or name == 19: - symbol = m0, v, c = symbols('m0 v c') - expr = m0*v/sqrt(1-v**2/c**2) - f = lambda x: x[:,[0]]*x[:,[1]]/torch.sqrt(1-x[:,[1]]**2/x[:,[2]]**2) - ranges = [[-1,1],[-0.9,0.9],[1.1,2]] - - if name == 'I.16.6' or name == 20: - symbol = u, v, c = symbols('u v c') - expr = (u+v)/(1+u*v/c**2) - f = lambda x: x[:,[0]]*x[:,[1]]/(1+x[:,[0]]*x[:,[1]]/x[:,[2]]**2) - ranges = [[-0.8,0.8],[-0.8,0.8],[1,2]] - - if name == 'I.18.4' or name == 21: - symbol = m1, r1, m2, r2 = symbols('m1 r1 m2 r2') - expr = (m1*r1+m2*r2)/(m1+m2) - f = lambda x: (x[:,[0]]*x[:,[1]]+x[:,[2]]*x[:,[3]])/(x[:,[0]]+x[:,[2]]) - ranges = [[0.5,1],[-1,1],[0.5,1],[-1,1]] - - if name == 'I.18.4' or name == 22: - symbol = r, F, theta = symbols('r F theta') - expr = r*F*sin(theta) - f = lambda x: x[:,[0]]*x[:,[1]]*torch.sin(x[:,[2]]) - ranges = [[-1,1],[-1,1],[0,2*tpi]] - - if name == 'I.18.16' or name == 23: - symbol = m, r, v, theta = symbols('m r v theta') - expr = m*r*v*sin(theta) - f = lambda x: 
x[:,[0]]*x[:,[1]]*x[:,[2]]*torch.sin(x[:,[3]]) - ranges = [[-1,1],[-1,1],[-1,1],[0,2*tpi]] - - if name == 'I.24.6' or name == 24: - symbol = m, omega, omega0, x = symbols('m omega omega_0 x') - expr = 1/4*m*(omega**2+omega0**2)*x**2 - f = lambda x: 1/4*x[:,[0]]*(x[:,[1]]**2+x[:,[2]]**2)*x[:,[3]]**2 - ranges = [[0,1],[-1,1],[-1,1],[-1,1]] - - if name == 'I.25.13' or name == 25: - symbol = q, C = symbols('q C') - expr = q/C - f = lambda x: x[:,[0]]/x[:,[1]] - ranges = [[-1,1],[0.5,2]] - - if name == 'I.26.2' or name == 26: - symbol = n, theta2 = symbols('n theta2') - expr = asin(n*sin(theta2)) - f = lambda x: torch.arcsin(x[:,[0]]*torch.sin(x[:,[1]])) - ranges = [[0,0.99],[0,2*tpi]] - - if name == 'I.27.6' or name == 27: - symbol = d1, d2, n = symbols('d1 d2 n') - expr = 1/(1/d1+n/d2) - f = lambda x: 1/(1/x[:,[0]]+x[:,[2]]/x[:,[1]]) - ranges = [[0.5,2],[1,2],[0.5,2]] - - if name == 'I.29.4' or name == 28: - symbol = omega, c = symbols('omega c') - expr = omega/c - f = lambda x: x[:,[0]]/x[:,[1]] - ranges = [[0,1],[0.5,2]] - - if name == 'I.29.16' or name == 29: - symbol = x1, x2, theta1, theta2 = symbols('x1 x2 theta1 theta2') - expr = sqrt(x1**2+x2**2-2*x1*x2*cos(theta1-theta2)) - f = lambda x: torch.sqrt(x[:,[0]]**2+x[:,[1]]**2-2*x[:,[0]]*x[:,[1]]*torch.cos(x[:,[2]]-x[:,[3]])) - ranges = [[-1,1],[-1,1],[0,2*tpi],[0,2*tpi]] - - if name == 'I.30.3' or name == 30: - symbol = I0, n, theta = symbols('I_0 n theta') - expr = I0 * sin(n*theta/2)**2 / sin(theta/2) ** 2 - f = lambda x: x[:,[0]] * torch.sin(x[:,[1]]*x[:,[2]]/2)**2 / torch.sin(x[:,[2]]/2)**2 - ranges = [[0,1],[0,4],[0.4*tpi,1.6*tpi]] - - if name == 'I.30.5' or name == 31: - symbol = lamb, n, d = symbols('lambda n d') - expr = asin(lamb/(n*d)) - f = lambda x: torch.arcsin(x[:,[0]]/(x[:,[1]]*x[:,[2]])) - ranges = [[-1,1],[1,1.5],[1,1.5]] - - if name == 'I.32.5' or name == 32: - symbol = q, a, eps, c = symbols('q a epsilon c') - expr = q**2*a**2/(eps*c**3) - f = lambda x: x[:,[0]]**2*x[:,[1]]**2/(x[:,[2]]*x[:,[3]]**3) - ranges = [[-1,1],[-1,1],[0.5,2],[0.5,2]] - - if name == 'I.32.17' or name == 33: - symbol = eps, c, Ef, r, omega, omega0 = symbols('epsilon c E_f r omega omega_0') - expr = nsimplify((1/2*eps*c*Ef**2)*(8*pi*r**2/3)*(omega**4/(omega**2-omega0**2)**2)) - f = lambda x: (1/2*x[:,[0]]*x[:,[1]]*x[:,[2]]**2)*(8*tpi*x[:,[3]]**2/3)*(x[:,[4]]**4/(x[:,[4]]**2-x[:,[5]]**2)**2) - ranges = [[0,1],[0,1],[-1,1],[0,1],[0,1],[1,2]] - - if name == 'I.34.8' or name == 34: - symbol = q, V, B, p = symbols('q V B p') - expr = q*V*B/p - f = lambda x: x[:,[0]]*x[:,[1]]*x[:,[2]]/x[:,[3]] - ranges = [[-1,1],[-1,1],[-1,1],[0.5,2]] - - if name == 'I.34.10' or name == 35: - symbol = omega0, v, c = symbols('omega_0 v c') - expr = omega0/(1-v/c) - f = lambda x: x[:,[0]]/(1-x[:,[1]]/x[:,[2]]) - ranges = [[0,1],[0,0.9],[1.1,2]] - - if name == 'I.34.14' or name == 36: - symbol = omega0, v, c = symbols('omega_0 v c') - expr = omega0 * (1+v/c)/sqrt(1-v**2/c**2) - f = lambda x: x[:,[0]]*(1+x[:,[1]]/x[:,[2]])/torch.sqrt(1-x[:,[1]]**2/x[:,[2]]**2) - ranges = [[0,1],[-0.9,0.9],[1.1,2]] - - if name == 'I.34.27' or name == 37: - symbol = hbar, omega = symbols('hbar omega') - expr = hbar * omega - f = lambda x: x[:,[0]]*x[:,[1]] - ranges = [[-1,1],[-1,1]] - - if name == 'I.37.4' or name == 38: - symbol = I1, I2, delta = symbols('I_1 I_2 delta') - expr = I1 + I2 + 2*sqrt(I1*I2)*cos(delta) - f = lambda x: x[:,[0]] + x[:,[1]] + 2*torch.sqrt(x[:,[0]]*x[:,[1]])*torch.cos(x[:,[2]]) - ranges = [[0.1,1],[0.1,1],[0,2*tpi]] - - if name == 'I.38.12' or name == 39: - symbol = 
eps, hbar, m, q = symbols('epsilon hbar m q') - expr = 4*pi*eps*hbar**2/(m*q**2) - f = lambda x: 4*tpi*x[:,[0]]*x[:,[1]]**2/(x[:,[2]]*x[:,[3]]**2) - ranges = [[0,1],[0,1],[0.5,2],[0.5,2]] - - if name == 'I.39.10' or name == 40: - symbol = pF, V = symbols('p_F V') - expr = 3/2 * pF * V - f = lambda x: 3/2 * x[:,[0]] * x[:,[1]] - ranges = [[0,1],[0,1]] - - if name == 'I.39.11' or name == 41: - symbol = gamma, pF, V = symbols('gamma p_F V') - expr = pF * V/(gamma - 1) - f = lambda x: 1/(x[:,[0]]-1) * x[:,[1]] * x[:,[2]] - ranges = [[1.5,3],[0,1],[0,1]] - - if name == 'I.39.22' or name == 42: - symbol = n, kb, T, V = symbols('n k_b T V') - expr = n*kb*T/V - f = lambda x: x[:,[0]]*x[:,[1]]*x[:,[2]]/x[:,[3]] - ranges = [[0,1],[0,1],[0,1],[0.5,2]] - - if name == 'I.40.1' or name == 43: - symbol = n0, m, g, x, kb, T = symbols('n_0 m g x k_b T') - expr = n0 * exp(-m*g*x/(kb*T)) - f = lambda x: x[:,[0]] * torch.exp(-x[:,[1]]*x[:,[2]]*x[:,[3]]/(x[:,[4]]*x[:,[5]])) - ranges = [[0,1],[-1,1],[-1,1],[-1,1],[1,2],[1,2]] - - if name == 'I.41.16' or name == 44: - symbol = hbar, omega, c, kb, T = symbols('hbar omega c k_b T') - expr = hbar * omega**3/(pi**2*c**2*(exp(hbar*omega/(kb*T))-1)) - f = lambda x: x[:,[0]]*x[:,[1]]**3/(tpi**2*x[:,[2]]**2*(torch.exp(x[:,[0]]*x[:,[1]]/(x[:,[3]]*x[:,[4]]))-1)) - ranges = [[0.5,1],[0.5,1],[0.5,2],[0.5,2],[0.5,2]] - - if name == 'I.43.16' or name == 45: - symbol = mu, q, Ve, d = symbols('mu q V_e d') - expr = mu*q*Ve/d - f = lambda x: x[:,[0]]*x[:,[1]]*x[:,[2]]/x[:,[3]] - ranges = [[0,1],[0,1],[0,1],[0.5,2]] - - if name == 'I.43.31' or name == 46: - symbol = mu, kb, T = symbols('mu k_b T') - expr = mu*kb*T - f = lambda x: x[:,[0]]*x[:,[1]]*x[:,[2]] - ranges = [[0,1],[0,1],[0,1]] - - if name == 'I.43.43' or name == 47: - symbol = gamma, kb, v, A = symbols('gamma k_b v A') - expr = kb*v/A/(gamma-1) - f = lambda x: 1/(x[:,[0]]-1)*x[:,[1]]*x[:,[2]]/x[:,[3]] - ranges = [[1.5,3],[0,1],[0,1],[0.5,2]] - - if name == 'I.44.4' or name == 48: - symbol = n, kb, T, V1, V2 = symbols('n k_b T V_1 V_2') - expr = n*kb*T*log(V2/V1) - f = lambda x: x[:,[0]]*x[:,[1]]*x[:,[2]]*torch.log(x[:,[4]]/x[:,[3]]) - ranges = [[0,1],[0,1],[0,1],[0.5,2],[0.5,2]] - - if name == 'I.47.23' or name == 49: - symbol = gamma, p, rho = symbols('gamma p rho') - expr = sqrt(gamma*p/rho) - f = lambda x: torch.sqrt(x[:,[0]]*x[:,[1]]/x[:,[2]]) - ranges = [[0.1,1],[0.1,1],[0.5,2]] - - if name == 'I.48.20' or name == 50: - symbol = m, v, c = symbols('m v c') - expr = m*c**2/sqrt(1-v**2/c**2) - f = lambda x: x[:,[0]]*x[:,[2]]**2/torch.sqrt(1-x[:,[1]]**2/x[:,[2]]**2) - ranges = [[0,1],[-0.9,0.9],[1.1,2]] - - if name == 'I.50.26' or name == 51: - symbol = x1, alpha, omega, t = symbols('x_1 alpha omega t') - expr = x1*(cos(omega*t)+alpha*cos(omega*t)**2) - f = lambda x: x[:,[0]]*(torch.cos(x[:,[2]]*x[:,[3]])+x[:,[1]]*torch.cos(x[:,[2]]*x[:,[3]])**2) - ranges = [[0,1],[0,1],[0,2*tpi],[0,1]] - - if name == 'II.2.42' or name == 52: - symbol = kappa, T1, T2, A, d = symbols('kappa T_1 T_2 A d') - expr = kappa*(T2-T1)*A/d - f = lambda x: x[:,[0]]*(x[:,[2]]-x[:,[1]])*x[:,[3]]/x[:,[4]] - ranges = [[0,1],[0,1],[0,1],[0,1],[0.5,2]] - - if name == 'II.3.24' or name == 53: - symbol = P, r = symbols('P r') - expr = P/(4*pi*r**2) - f = lambda x: x[:,[0]]/(4*tpi*x[:,[1]]**2) - ranges = [[0,1],[0.5,2]] - - if name == 'II.4.23' or name == 54: - symbol = q, eps, r = symbols('q epsilon r') - expr = q/(4*pi*eps*r) - f = lambda x: x[:,[0]]/(4*tpi*x[:,[1]]*x[:,[2]]) - ranges = [[0,1],[0.5,2],[0.5,2]] - - if name == 'II.6.11' or name == 55: - 
symbol = eps, pd, theta, r = symbols('epsilon p_d theta r') - expr = 1/(4*pi*eps)*pd*cos(theta)/r**2 - f = lambda x: 1/(4*tpi*x[:,[0]])*x[:,[1]]*torch.cos(x[:,[2]])/x[:,[3]]**2 - ranges = [[0.5,2],[0,1],[0,2*tpi],[0.5,2]] - - if name == 'II.6.15a' or name == 56: - symbol = eps, pd, z, x, y, r = symbols('epsilon p_d z x y r') - expr = 3/(4*pi*eps)*pd*z/r**5*sqrt(x**2+y**2) - f = lambda x: 3/(4*tpi*x[:,[0]])*x[:,[1]]*x[:,[2]]/x[:,[5]]**5*torch.sqrt(x[:,[3]]**2+x[:,[4]]**2) - ranges = [[0.5,2],[0,1],[0,1],[0,1],[0,1],[0.5,2]] - - if name == 'II.6.15b' or name == 57: - symbol = eps, pd, r, theta = symbols('epsilon p_d r theta') - expr = 3/(4*pi*eps)*pd/r**3*cos(theta)*sin(theta) - f = lambda x: 3/(4*tpi*x[:,[0]])*x[:,[1]]/x[:,[2]]**3*torch.cos(x[:,[3]])*torch.sin(x[:,[3]]) - ranges = [[0.5,2],[0,1],[0.5,2],[0,2*tpi]] - - if name == 'II.8.7' or name == 58: - symbol = q, eps, d = symbols('q epsilon d') - expr = 3/5*q**2/(4*pi*eps*d) - f = lambda x: 3/5*x[:,[0]]**2/(4*tpi*x[:,[1]]*x[:,[2]]) - ranges = [[0,1],[0.5,2],[0.5,2]] - - if name == 'II.8.31' or name == 59: - symbol = eps, Ef = symbols('epsilon E_f') - expr = 1/2*eps*Ef**2 - f = lambda x: 1/2*x[:,[0]]*x[:,[1]]**2 - ranges = [[0,1],[0,1]] - - if name == 'I.10.9' or name == 60: - symbol = sigma, eps, chi = symbols('sigma epsilon chi') - expr = sigma/eps/(1+chi) - f = lambda x: x[:,[0]]/x[:,[1]]/(1+x[:,[2]]) - ranges = [[0,1],[0.5,2],[0,1]] - - if name == 'II.11.3' or name == 61: - symbol = q, Ef, m, omega0, omega = symbols('q E_f m omega_o omega') - expr = q*Ef/(m*(omega0**2-omega**2)) - f = lambda x: x[:,[0]]*x[:,[1]]/(x[:,[2]]*(x[:,[3]]**2-x[:,[4]]**2)) - ranges = [[0,1],[0,1],[0.5,2],[1.5,3],[0,1]] - - if name == 'II.11.17' or name == 62: - symbol = n0, pd, Ef, theta, kb, T = symbols('n_0 p_d E_f theta k_b T') - expr = n0*(1+pd*Ef*cos(theta)/(kb*T)) - f = lambda x: x[:,[0]]*(1+x[:,[1]]*x[:,[2]]*torch.cos(x[:,[3]])/(x[:,[4]]*x[:,[5]])) - ranges = [[0,1],[-1,1],[-1,1],[0,2*tpi],[0.5,2],[0.5,2]] - - - if name == 'II.11.20' or name == 63: - symbol = n, pd, Ef, kb, T = symbols('n p_d E_f k_b T') - expr = n*pd**2*Ef/(3*kb*T) - f = lambda x: x[:,[0]]*x[:,[1]]**2*x[:,[2]]/(3*x[:,[3]]*x[:,[4]]) - ranges = [[0,1],[0,1],[0,1],[0.5,2],[0.5,2]] - - if name == 'II.11.27' or name == 64: - symbol = n, alpha, eps, Ef = symbols('n alpha epsilon E_f') - expr = n*alpha/(1-n*alpha/3)*eps*Ef - f = lambda x: x[:,[0]]*x[:,[1]]/(1-x[:,[0]]*x[:,[1]]/3)*x[:,[2]]*x[:,[3]] - ranges = [[0,1],[0,2],[0,1],[0,1]] - - if name == 'II.11.28' or name == 65: - symbol = n, alpha = symbols('n alpha') - expr = 1 + n*alpha/(1-n*alpha/3) - f = lambda x: 1 + x[:,[0]]*x[:,[1]]/(1-x[:,[0]]*x[:,[1]]/3) - ranges = [[0,1],[0,2]] - - if name == 'II.13.17' or name == 66: - symbol = eps, c, l, r = symbols('epsilon c l r') - expr = 1/(4*pi*eps*c**2)*(2*l/r) - f = lambda x: 1/(4*tpi*x[:,[0]]*x[:,[1]]**2)*(2*x[:,[2]]/x[:,[3]]) - ranges = [[0.5,2],[0.5,2],[0,1],[0.5,2]] - - if name == 'II.13.23' or name == 67: - symbol = rho, v, c = symbols('rho v c') - expr = rho/sqrt(1-v**2/c**2) - f = lambda x: x[:,[0]]/torch.sqrt(1-x[:,[1]]**2/x[:,[2]]**2) - ranges = [[0,1],[0,1],[1,2]] - - if name == 'II.13.34' or name == 68: - symbol = rho, v, c = symbols('rho v c') - expr = rho*v/sqrt(1-v**2/c**2) - f = lambda x: x[:,[0]]*x[:,[1]]/torch.sqrt(1-x[:,[1]]**2/x[:,[2]]**2) - ranges = [[0,1],[0,1],[1,2]] - - if name == 'II.15.4' or name == 69: - symbol = muM, B, theta = symbols('mu_M B theta') - expr = - muM * B * cos(theta) - f = lambda x: - x[:,[0]]*x[:,[1]]*torch.cos(x[:,[2]]) - ranges = 
[[0,1],[0,1],[0,2*tpi]] - - if name == 'II.15.5' or name == 70: - symbol = pd, Ef, theta = symbols('p_d E_f theta') - expr = - pd * Ef * cos(theta) - f = lambda x: - x[:,[0]]*x[:,[1]]*torch.cos(x[:,[2]]) - ranges = [[0,1],[0,1],[0,2*tpi]] - - if name == 'II.21.32' or name == 71: - symbol = q, eps, r, v, c = symbols('q epsilon r v c') - expr = q/(4*pi*eps*r*(1-v/c)) - f = lambda x: x[:,[0]]/(4*tpi*x[:,[1]]*x[:,[2]]*(1-x[:,[3]]/x[:,[4]])) - ranges = [[0,1],[0.5,2],[0.5,2],[0,1],[1,2]] - - if name == 'II.24.17' or name == 72: - symbol = omega, c, d = symbols('omega c d') - expr = sqrt(omega**2/c**2-pi**2/d**2) - f = lambda x: torch.sqrt(x[:,[0]]**2/x[:,[1]]**2-tpi**2/x[:,[2]]**2) - ranges = [[1,1.5],[0.75,1],[1*tpi,1.5*tpi]] - - if name == 'II.27.16' or name == 73: - symbol = eps, c, Ef = symbols('epsilon c E_f') - expr = eps * c * Ef**2 - f = lambda x: x[:,[0]]*x[:,[1]]*x[:,[2]]**2 - ranges = [[0,1],[0,1],[-1,1]] - - if name == 'II.27.18' or name == 74: - symbol = eps, Ef = symbols('epsilon E_f') - expr = eps * Ef**2 - f = lambda x: x[:,[0]]*x[:,[1]]**2 - ranges = [[0,1],[-1,1]] - - if name == 'II.34.2a' or name == 75: - symbol = q, v, r = symbols('q v r') - expr = q*v/(2*pi*r) - f = lambda x: x[:,[0]]*x[:,[1]]/(2*tpi*x[:,[2]]) - ranges = [[0,1],[0,1],[0.5,2]] - - if name == 'II.34.2' or name == 76: - symbol = q, v, r = symbols('q v r') - expr = q*v*r/2 - f = lambda x: x[:,[0]]*x[:,[1]]*x[:,[2]]/2 - ranges = [[0,1],[0,1],[0,1]] - - if name == 'II.34.11' or name == 77: - symbol = g, q, B, m = symbols('g q B m') - expr = g*q*B/(2*m) - f = lambda x: x[:,[0]]*x[:,[1]]*x[:,[2]]/(2*x[:,[3]]) - ranges = [[0,1],[0,1],[0,1],[0.5,2]] - - if name == 'II.34.29a' or name == 78: - symbol = q, h, m = symbols('q h m') - expr = q*h/(4*pi*m) - f = lambda x: x[:,[0]]*x[:,[1]]/(4*tpi*x[:,[2]]) - ranges = [[0,1],[0,1],[0.5,2]] - - if name == 'II.34.29b' or name == 79: - symbol = g, mu, B, J, hbar = symbols('g mu B J hbar') - expr = g*mu*B*J/hbar - f = lambda x: x[:,[0]]*x[:,[1]]*x[:,[2]]*x[:,[3]]/x[:,[4]] - ranges = [[0,1],[0,1],[0,1],[0,1],[0.5,2]] - - if name == 'II.35.18' or name == 80: - symbol = n0, mu, B, kb, T = symbols('n0 mu B k_b T') - expr = n0/(exp(mu*B/(kb*T))+exp(-mu*B/(kb*T))) - f = lambda x: x[:,[0]]/(torch.exp(x[:,[1]]*x[:,[2]]/(x[:,[3]]*x[:,[4]]))+torch.exp(-x[:,[1]]*x[:,[2]]/(x[:,[3]]*x[:,[4]]))) - ranges = [[0,1],[0,1],[0,1],[0.5,2],[0.5,2]] - - if name == 'II.35.21' or name == 81: - symbol = n, mu, B, kb, T = symbols('n mu B k_b T') - expr = n*mu*tanh(mu*B/(kb*T)) - f = lambda x: x[:,[0]]*x[:,[1]]*torch.tanh(x[:,[1]]*x[:,[2]]/(x[:,[3]]*x[:,[4]])) - ranges = [[0,1],[0,1],[0,1],[0.5,2],[0.5,2]] - - if name == 'II.36.38' or name == 82: - symbol = mu, B, kb, T, alpha, M, eps, c = symbols('mu B k_b T alpha M epsilon c') - expr = mu*B/(kb*T) + mu*alpha*M/(eps*c**2*kb*T) - f = lambda x: x[:,[0]]*x[:,[1]]/(x[:,[2]]*x[:,[3]]) + x[:,[0]]*x[:,[4]]*x[:,[5]]/(x[:,[6]]*x[:,[7]]**2*x[:,[2]]*x[:,[3]]) - ranges = [[0,1],[0,1],[0.5,2],[0.5,2],[0,1],[0,1],[0.5,2],[0.5,2]] - - if name == 'II.37.1' or name == 83: - symbol = mu, chi, B = symbols('mu chi B') - expr = mu*(1+chi)*B - f = lambda x: x[:,[0]]*(1+x[:,[1]])*x[:,[2]] - ranges = [[0,1],[0,1],[0,1]] - - if name == 'II.38.3' or name == 84: - symbol = Y, A, x, d = symbols('Y A x d') - expr = Y*A*x/d - f = lambda x: x[:,[0]]*x[:,[1]]*x[:,[2]]/x[:,[3]] - ranges = [[0,1],[0,1],[0,1],[0.5,2]] - - if name == 'II.38.14' or name == 85: - symbol = Y, sigma = symbols('Y sigma') - expr = Y/(2*(1+sigma)) - f = lambda x: x[:,[0]]/(2*(1+x[:,[1]])) - ranges = [[0,1],[0,1]] 
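Every branch of `get_feynman_dataset` returns the same four-tuple `(symbol, expr, f, ranges)`, so any equation can be sampled the same way; a minimal sketch (assuming the `kan.feynman` module path used in this diff, and an entry whose range bounds are plain numbers):

```python
import torch
from kan.feynman import get_feynman_dataset

symbol, expr, f, ranges = get_feynman_dataset('I.6.20a')

# Draw each input uniformly from its [low, high] range, then evaluate f.
x = torch.rand(1000, len(ranges))
for i, (lo, hi) in enumerate(ranges):
    x[:, i] = x[:, i] * (hi - lo) + lo
y = f(x)  # shape (1000, 1)
```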
- - if name == 'III.4.32' or name == 86: - symbol = hbar, omega, kb, T = symbols('hbar omega k_b T') - expr = 1/(exp(hbar*omega/(kb*T))-1) - f = lambda x: 1/(torch.exp(x[:,[0]]*x[:,[1]]/(x[:,[2]]*x[:,[3]]))-1) - ranges = [[0.5,1],[0.5,1],[0.5,2],[0.5,2]] - - if name == 'III.4.33' or name == 87: - symbol = hbar, omega, kb, T = symbols('hbar omega k_b T') - expr = hbar*omega/(exp(hbar*omega/(kb*T))-1) - f = lambda x: x[:,[0]]*x[:,[1]]/(torch.exp(x[:,[0]]*x[:,[1]]/(x[:,[2]]*x[:,[3]]))-1) - ranges = [[0,1],[0,1],[0.5,2],[0.5,2]] - - if name == 'III.7.38' or name == 88: - symbol = mu, B, hbar = symbols('mu B hbar') - expr = 2*mu*B/hbar - f = lambda x: 2*x[:,[0]]*x[:,[1]]/x[:,[2]] - ranges = [[0,1],[0,1],[0.5,2]] - - if name == 'III.8.54' or name == 89: - symbol = E, t, hbar = symbols('E t hbar') - expr = sin(E*t/hbar)**2 - f = lambda x: torch.sin(x[:,[0]]*x[:,[1]]/x[:,[2]])**2 - ranges = [[0,2*tpi],[0,1],[0.5,2]] - - if name == 'III.9.52' or name == 90: - symbol = pd, Ef, t, hbar, omega, omega0 = symbols('p_d E_f t hbar omega omega_0') - expr = pd*Ef*t/hbar*sin((omega-omega0)*t/2)**2/((omega-omega0)*t/2)**2 - f = lambda x: x[:,[0]]*x[:,[1]]*x[:,[2]]/x[:,[3]]*torch.sin((x[:,[4]]-x[:,[5]])*x[:,[2]]/2)**2/((x[:,[4]]-x[:,[5]])*x[:,[2]]/2)**2 - ranges = [[0,1],[0,1],[0,1],[0.5,2],[0,tpi],[0,tpi]] - - if name == 'III.10.19' or name == 91: - symbol = mu, Bx, By, Bz = symbols('mu B_x B_y B_z') - expr = mu*sqrt(Bx**2+By**2+Bz**2) - f = lambda x: x[:,[0]]*torch.sqrt(x[:,[1]]**2+x[:,[2]]**2+x[:,[3]]**2) - ranges = [[0,1],[0,1],[0,1],[0,1]] - - if name == 'III.12.43' or name == 92: - symbol = n, hbar = symbols('n hbar') - expr = n * hbar - f = lambda x: x[:,[0]]*x[:,[1]] - ranges = [[0,1],[0,1]] - - if name == 'III.13.18' or name == 93: - symbol = E, d, k, hbar = symbols('E d k hbar') - expr = 2*E*d**2*k/hbar - f = lambda x: 2*x[:,[0]]*x[:,[1]]**2*x[:,[2]]/x[:,[3]] - ranges = [[0,1],[0,1],[0,1],[0.5,2]] - - if name == 'III.14.14' or name == 94: - symbol = I0, q, Ve, kb, T = symbols('I_0 q V_e k_b T') - expr = I0 * (exp(q*Ve/(kb*T))-1) - f = lambda x: x[:,[0]]*(torch.exp(x[:,[1]]*x[:,[2]]/(x[:,[3]]*x[:,[4]]))-1) - ranges = [[0,1],[0,1],[0,1],[0.5,2],[0.5,2]] - - if name == 'III.15.12' or name == 95: - symbol = U, k, d = symbols('U k d') - expr = 2*U*(1-cos(k*d)) - f = lambda x: 2*x[:,[0]]*(1-torch.cos(x[:,[1]]*x[:,[2]])) - ranges = [[0,1],[0,2*tpi],[0,1]] - - if name == 'III.15.14' or name == 96: - symbol = hbar, E, d = symbols('hbar E d') - expr = hbar**2/(2*E*d**2) - f = lambda x: x[:,[0]]**2/(2*x[:,[1]]*x[:,[2]]**2) - ranges = [[0,1],[0.5,2],[0.5,2]] - - if name == 'III.15.27' or name == 97: - symbol = alpha, n, d = symbols('alpha n d') - expr = 2*pi*alpha/(n*d) - f = lambda x: 2*tpi*x[:,[0]]/(x[:,[1]]*x[:,[2]]) - ranges = [[0,1],[0.5,2],[0.5,2]] - - if name == 'III.17.37' or name == 98: - symbol = beta, alpha, theta = symbols('beta alpha theta') - expr = beta * (1+alpha*cos(theta)) - f = lambda x: x[:,[0]]*(1+x[:,[1]]*torch.cos(x[:,[2]])) - ranges = [[0,1],[0,1],[0,2*tpi]] - - if name == 'III.19.51' or name == 99: - symbol = m, q, eps, hbar, n = symbols('m q epsilon hbar n') - expr = - m * q**4/(2*(4*pi*eps)**2*hbar**2)*1/n**2 - f = lambda x: - x[:,[0]]*x[:,[1]]**4/(2*(4*tpi*x[:,[2]])**2*x[:,[3]]**2)*1/x[:,[4]]**2 - ranges = [[0,1],[0,1],[0.5,2],[0.5,2],[0.5,2]] - - if name == 'III.21.20' or name == 100: - symbol = rho, q, A, m = symbols('rho q A m') - expr = - rho*q*A/m - f = lambda x: - x[:,[0]]*x[:,[1]]*x[:,[2]]/x[:,[3]] - ranges = [[0,1],[0,1],[0,1],[0.5,2]] - - if name == 'Rutherforld scattering' 
or name == 101: - symbol = Z1, Z2, alpha, hbar, c, E, theta = symbols('Z_1 Z_2 alpha hbar c E theta') - expr = (Z1*Z2*alpha*hbar*c/(4*E*sin(theta/2)**2))**2 - f = lambda x: (x[:,[0]]*x[:,[1]]*x[:,[2]]*x[:,[3]]*x[:,[4]]/(4*x[:,[5]]*torch.sin(x[:,[6]]/2)**2))**2 - ranges = [[0,1],[0,1],[0,1],[0,1],[0,1],[0.5,2],[0.1*tpi,0.9*tpi]] - - if name == 'Friedman equation' or name == 102: - symbol = G, rho, kf, c, af = symbols('G rho k_f c a_f') - expr = sqrt(8*pi*G/3*rho-kf*c**2/af**2) - f = lambda x: torch.sqrt(8*tpi*x[:,[0]]/3*x[:,[1]] - x[:,[2]]*x[:,[3]]**2/x[:,[4]]**2) - ranges = [[1,2],[1,2],[0,1],[0,1],[1,2]] - - if name == 'Compton scattering' or name == 103: - symbol = E, m, c, theta = symbols('E m c theta') - expr = E/(1+E/(m*c**2)*(1-cos(theta))) - f = lambda x: x[:,[0]]/(1+x[:,[0]]/(x[:,[1]]*x[:,[2]]**2)*(1-torch.cos(x[:,[3]]))) - ranges = [[0,1],[0.5,2],[0.5,2],[0,2*tpi]] - - if name == 'Radiated gravitational wave power' or name == 104: - symbol = G, c, m1, m2, r = symbols('G c m_1 m_2 r') - expr = -32/5*G**4/c**5*(m1*m2)**2*(m1+m2)/r**5 - f = lambda x: -32/5*x[:,[0]]**4/x[:,[1]]**5*(x[:,[2]]*x[:,[3]])**2*(x[:,[2]]+x[:,[3]])/x[:,[4]]**5 - ranges = [[0,1],[0.5,2],[0,1],[0,1],[0.5,2]] - - if name == 'Relativistic aberration' or name == 105: - symbol = theta2, v, c = symbols('theta_2 v c') - expr = acos((cos(theta2)-v/c)/(1-v/c*cos(theta2))) - f = lambda x: torch.arccos((torch.cos(x[:,[0]])-x[:,[1]]/x[:,[2]])/(1-x[:,[1]]/x[:,[2]]*torch.cos(x[:,[0]]))) - ranges = [[0,tpi],[0,1],[1,2]] - - if name == 'N-slit diffraction' or name == 106: - symbol = I0, alpha, delta, N = symbols('I_0 alpha delta N') - expr = I0 * (sin(alpha/2)/(alpha/2)*sin(N*delta/2)/sin(delta/2))**2 - f = lambda x: x[:,[0]] * (torch.sin(x[:,[1]]/2)/(x[:,[1]]/2)*torch.sin(x[:,[3]]*x[:,[2]]/2)/torch.sin(x[:,[2]]/2))**2 - ranges = [[0,1],[0.1*tpi,0.9*tpi],[0.1*tpi,0.9*tpi],[0.5,1]] - - if name == 'Goldstein 3.16' or name == 107: - symbol = m, E, U, L, r = symbols('m E U L r') - expr = sqrt(2/m*(E-U-L**2/(2*m*r**2))) - f = lambda x: torch.sqrt(2/x[:,[0]]*(x[:,[1]]-x[:,[2]]-x[:,[3]]**2/(2*x[:,[0]]*x[:,[4]]**2))) - ranges = [[1,2],[2,3],[0,1],[0,1],[1,2]] - - if name == 'Goldstein 3.55' or name == 108: - symbol = m, kG, L, E, theta1, theta2 = symbols('m k_G L E theta_1 theta_2') - expr = m*kG/L**2*(1+sqrt(1+2*E*L**2/(m*kG**2))*cos(theta1-theta2)) - f = lambda x: x[:,[0]]*x[:,[1]]/x[:,[2]]**2*(1+torch.sqrt(1+2*x[:,[3]]*x[:,[2]]**2/(x[:,[0]]*x[:,[1]]**2))*torch.cos(x[:,[4]]-x[:,[5]])) - ranges = [[0.5,2],[0.5,2],[0.5,2],[0,1],[0,2*tpi],[0,2*tpi]] - - if name == 'Goldstein 3.64 (ellipse)' or name == 109: - symbol = d, alpha, theta1, theta2 = symbols('d alpha theta_1 theta_2') - expr = d*(1-alpha**2)/(1+alpha*cos(theta2-theta1)) - f = lambda x: x[:,[0]]*(1-x[:,[1]]**2)/(1+x[:,[1]]*torch.cos(x[:,[2]]-x[:,[3]])) - ranges = [[0,1],[0,0.9],[0,2*tpi],[0,2*tpi]] - - if name == 'Goldstein 3.74 (Kepler)' or name == 110: - symbol = d, G, m1, m2 = symbols('d G m_1 m_2') - expr = 2*pi*d**(3/2)/sqrt(G*(m1+m2)) - f = lambda x: 2*tpi*x[:,[0]]**(3/2)/torch.sqrt(x[:,[1]]*(x[:,[2]]+x[:,[3]])) - ranges = [[0,1],[0.5,2],[0.5,2],[0.5,2]] - - if name == 'Goldstein 3.99' or name == 111: - symbol = eps, E, L, m, Z1, Z2, q = symbols('epsilon E L m Z_1 Z_2 q') - expr = sqrt(1+2*eps**2*E*L**2/(m*(Z1*Z2*q**2)**2)) - f = lambda x: torch.sqrt(1+2*x[:,[0]]**2*x[:,[1]]*x[:,[2]]**2/(x[:,[3]]*(x[:,[4]]*x[:,[5]]*x[:,[6]]**2)**2)) - ranges = [[0,1],[0,1],[0,1],[0.5,2],[0.5,2],[0.5,2],[0.5,2]] - - if name == 'Goldstein 8.56' or name == 112: - symbol = p, q, A, c, m, Ve 
= symbols('p q A c m V_e') - expr = sqrt((p-q*A)**2*c**2+m**2*c**4) + q*Ve - f = lambda x: torch.sqrt((x[:,[0]]-x[:,[1]]*x[:,[2]])**2*x[:,[3]]**2+x[:,[4]]**2*x[:,[3]]**4) + x[:,[1]]*x[:,[5]] - ranges = [0,1] - - if name == 'Goldstein 12.80' or name == 113: - symbol = m, p, omega, x, alpha, y = symbols('m p omega x alpha y') - expr = 1/(2*m)*(p**2+m**2*omega**2*x**2*(1+alpha*y/x)) - f = lambda x: 1/(2*x[:,[0]]) * (x[:,[1]]**2+x[:,[0]]**2*x[:,[2]]**2*x[:,[3]]**2*(1+x[:,[4]]*x[:,[3]]/x[:,[5]])) - ranges = [[0.5,2],[0,1],[0,1],[0,1],[0,1],[0.5,2]] - - if name == 'Jackson 2.11' or name == 114: - symbol = q, eps, y, Ve, d = symbols('q epsilon y V_e d') - expr = q/(4*pi*eps*y**2)*(4*pi*eps*Ve*d-q*d*y**3/(y**2-d**2)**2) - f = lambda x: x[:,[0]]/(4*tpi*x[:,[1]]*x[:,x[:,[2]]]**2)*(4*tpi*x[:,[1]]*x[:,[3]]*x[:,[4]]-x[:,[0]]*x[:,[4]]*x[:,[2]]**3/(x[:,[2]]**2-x[:,[4]]**2)**2) - ranges = [[0,1],[0.5,2],[1,2],[0,1],[0,1]] - - if name == 'Jackson 3.45' or name == 115: - symbol = q, r, d, alpha = symbols('q r d alpha') - expr = q/sqrt(r**2+d**2-2*d*r*cos(alpha)) - f = lambda x: x[:,[0]]/torch.sqrt(x[:,[1]]**2+x[:,[2]]**2-2*x[:,[1]]*x[:,[2]]*torch.cos(x[:,[3]])) - ranges = [[0,1],[0,1],[0,1],[0,2*tpi]] - - if name == 'Jackson 4.60' or name == 116: - symbol = Ef, theta, alpha, d, r = symbols('E_f theta alpha d r') - expr = Ef * cos(theta) * ((alpha-1)/(alpha+2) * d**3/r**2 - r) - f = lambda x: x[:,[0]] * torch.cos(x[:,[1]]) * ((x[:,[2]]-1)/(x[:,[2]]+2) * x[:,[3]]**3/x[:,[4]]**2 - x[:,[4]]) - ranges = [[0,1],[0,2*tpi],[0,2],[0,1],[0.5,2]] - - if name == 'Jackson 11.38 (Doppler)' or name == 117: - symbol = omega, v, c, theta = symbols('omega v c theta') - expr = sqrt(1-v**2/c**2)/(1+v/c*cos(theta))*omega - f = lambda x: torch.sqrt(1-x[:,[1]]**2/x[:,[2]]**2)/(1+x[:,[1]]/x[:,[2]]*torch.cos(x[:,[3]]))*x[:,[0]] - ranges = [[0,1],[0,1],[1,2],[0,2*tpi]] - - if name == 'Weinberg 15.2.1' or name == 118: - symbol = G, c, kf, af, H = symbols('G c k_f a_f H') - expr = 3/(8*pi*G)*(c**2*kf/af**2+H**2) - f = lambda x: 3/(8*tpi*x[:,[0]])*(x[:,[1]]**2*x[:,[2]]/x[:,[3]]**2+x[:,[4]]**2) - ranges = [[0.5,2],[0,1],[0,1],[0.5,2],[0,1]] - - if name == 'Weinberg 15.2.2' or name == 119: - symbol = G, c, kf, af, H, alpha = symbols('G c k_f a_f H alpha') - expr = -1/(8*pi*G)*(c**4*kf/af**2+c**2*H**2*(1-2*alpha)) - f = lambda x: -1/(8*tpi*x[:,[0]])*(x[:,[1]]**4*x[:,[2]]/x[:,[3]]**2 + x[:,[1]]**2*x[:,[4]]**2*(1-2*x[:,[5]])) - ranges = [[0.5,2],[0,1],[0,1],[0.5,2],[0,1],[0,1]] - - if name == 'Schwarz 13.132 (Klein-Nishina)' or name == 120: - symbol = alpha, hbar, m, c, omega0, omega, theta = symbols('alpha hbar m c omega_0 omega theta') - expr = pi*alpha**2*hbar**2/m**2/c**2*(omega0/omega)**2*(omega0/omega+omega/omega0-sin(theta)**2) - f = lambda x: tpi*x[:,[0]]**2*x[:,[1]]**2/x[:,[2]]**2/x[:,[3]]**2*(x[:,[4]]/x[:,[5]])**2*(x[:,[4]]/x[:,[5]]+x[:,[5]]/x[:,[4]]-torch.sin(x[:,[6]])**2) - ranges = [[0,1],[0,1],[0.5,2],[0.5,2],[0.5,2],[0.5,2],[0,2*tpi]] - - return symbol, expr, f, ranges \ No newline at end of file diff --git a/models/others/kolmogorov_arnold_networks/kan/pytorch/kan/hypothesis.py b/models/others/kolmogorov_arnold_networks/kan/pytorch/kan/hypothesis.py deleted file mode 100644 index 4850f509849c9efe21437b30b9ca2f220bf38181..0000000000000000000000000000000000000000 --- a/models/others/kolmogorov_arnold_networks/kan/pytorch/kan/hypothesis.py +++ /dev/null @@ -1,695 +0,0 @@ -import numpy as np -import torch -from sklearn.linear_model import LinearRegression -from sympy.utilities.lambdify import lambdify -from sklearn.cluster 
import AgglomerativeClustering -from .utils import batch_jacobian, batch_hessian -from functools import reduce -from kan.utils import batch_jacobian, batch_hessian -import copy -import matplotlib.pyplot as plt -import sympy -from sympy.printing import latex - - -def detect_separability(model, x, mode='add', score_th=1e-2, res_th=1e-2, n_clusters=None, bias=0., verbose=False): - ''' - detect function separability - - Args: - ----- - model : MultKAN, MLP or python function - x : 2D torch.float - inputs - mode : str - mode = 'add' or mode = 'mul' - score_th : float - threshold of score - res_th : float - threshold of residue - n_clusters : None or int - the number of clusters - bias : float - bias (for multiplicative separability) - verbose : bool - - Returns: - -------- - results (dictionary) - - Example1 - -------- - >>> from kan.hypothesis import * - >>> model = lambda x: x[:,[0]] ** 2 + torch.exp(x[:,[1]]+x[:,[2]]) - >>> x = torch.normal(0,1,size=(100,3)) - >>> detect_separability(model, x, mode='add') - - Example2 - -------- - >>> from kan.hypothesis import * - >>> model = lambda x: x[:,[0]] ** 2 * (x[:,[1]]+x[:,[2]]) - >>> x = torch.normal(0,1,size=(100,3)) - >>> detect_separability(model, x, mode='mul') - ''' - results = {} - - if mode == 'add': - hessian = batch_hessian(model, x) - elif mode == 'mul': - compose = lambda *F: reduce(lambda f, g: lambda x: f(g(x)), F) - hessian = batch_hessian(compose(torch.log, torch.abs, lambda x: x+bias, model), x) - - std = torch.std(x, dim=0) - hessian_normalized = hessian * std[None,:] * std[:,None] - score_mat = torch.median(torch.abs(hessian_normalized), dim=0)[0] - results['hessian'] = score_mat - - dist_hard = (score_mat < score_th).float() - - if isinstance(n_clusters, int): - n_cluster_try = [n_clusters, n_clusters] - elif isinstance(n_clusters, list): - n_cluster_try = n_clusters - else: - n_cluster_try = [1,x.shape[1]] - - n_cluster_try = list(range(n_cluster_try[0], n_cluster_try[1]+1)) - - for n_cluster in n_cluster_try: - - clustering = AgglomerativeClustering( - metric='precomputed', - n_clusters=n_cluster, - linkage='complete', - ).fit(dist_hard) - - labels = clustering.labels_ - - groups = [list(np.where(labels == i)[0]) for i in range(n_cluster)] - blocks = [torch.sum(score_mat[groups[i]][:,groups[i]]) for i in range(n_cluster)] - block_sum = torch.sum(torch.stack(blocks)) - total_sum = torch.sum(score_mat) - residual_sum = total_sum - block_sum - residual_ratio = residual_sum / total_sum - - if verbose == True: - print(f'n_group={n_cluster}, residual_ratio={residual_ratio}') - - if residual_ratio < res_th: - results['n_groups'] = n_cluster - results['labels'] = list(labels) - results['groups'] = groups - - if results['n_groups'] > 1: - print(f'{mode} separability detected') - else: - print(f'{mode} separability not detected') - - return results - - -def batch_grad_normgrad(model, x, group, create_graph=False): - # x in shape (Batch, Length) - group_A = group - group_B = list(set(range(x.shape[1])) - set(group)) - - def jac(x): - input_grad = batch_jacobian(model, x, create_graph=True) - input_grad_A = input_grad[:,group_A] - norm = torch.norm(input_grad_A, dim=1, keepdim=True) + 1e-6 - input_grad_A_normalized = input_grad_A/norm - return input_grad_A_normalized - - def _jac_sum(x): - return jac(x).sum(dim=0) - - return torch.autograd.functional.jacobian(_jac_sum, x, create_graph=create_graph).permute(1,0,2)[:,:,group_B] - - -def get_dependence(model, x, group): - group_A = group - group_B = list(set(range(x.shape[1])) - set(group)) 
- grad_normgrad = batch_grad_normgrad(model, x, group=group) - std = torch.std(x, dim=0) - dependence = grad_normgrad * std[None,group_A,None] * std[None,None,group_B] - dependence = torch.median(torch.abs(dependence), dim=0)[0] - return dependence - -def test_symmetry(model, x, group, dependence_th=1e-3): - ''' - detect function separability - - Args: - ----- - model : MultKAN, MLP or python function - x : 2D torch.float - inputs - group : a list of indices - dependence_th : float - threshold of dependence - - Returns: - -------- - bool - - Example - ------- - >>> from kan.hypothesis import * - >>> model = lambda x: x[:,[0]] ** 2 * (x[:,[1]]+x[:,[2]]) - >>> x = torch.normal(0,1,size=(100,3)) - >>> print(test_symmetry(model, x, [1,2])) # True - >>> print(test_symmetry(model, x, [0,2])) # False - ''' - if len(group) == x.shape[1] or len(group) == 0: - return True - - dependence = get_dependence(model, x, group) - max_dependence = torch.max(dependence) - return max_dependence < dependence_th - - -def test_separability(model, x, groups, mode='add', threshold=1e-2, bias=0): - ''' - test function separability - - Args: - ----- - model : MultKAN, MLP or python function - x : 2D torch.float - inputs - mode : str - mode = 'add' or mode = 'mul' - score_th : float - threshold of score - res_th : float - threshold of residue - bias : float - bias (for multiplicative separability) - verbose : bool - - Returns: - -------- - bool - - Example - ------- - >>> from kan.hypothesis import * - >>> model = lambda x: x[:,[0]] ** 2 * (x[:,[1]]+x[:,[2]]) - >>> x = torch.normal(0,1,size=(100,3)) - >>> print(test_separability(model, x, [[0],[1,2]], mode='mul')) # True - >>> print(test_separability(model, x, [[0],[1,2]], mode='add')) # False - ''' - if mode == 'add': - hessian = batch_hessian(model, x) - elif mode == 'mul': - compose = lambda *F: reduce(lambda f, g: lambda x: f(g(x)), F) - hessian = batch_hessian(compose(torch.log, torch.abs, lambda x: x+bias, model), x) - - std = torch.std(x, dim=0) - hessian_normalized = hessian * std[None,:] * std[:,None] - score_mat = torch.median(torch.abs(hessian_normalized), dim=0)[0] - - sep_bool = True - - # internal test - n_groups = len(groups) - for i in range(n_groups): - for j in range(i+1, n_groups): - sep_bool *= torch.max(score_mat[groups[i]][:,groups[j]]) < threshold - - # external test - group_id = [x for xs in groups for x in xs] - nongroup_id = list(set(range(x.shape[1])) - set(group_id)) - if len(nongroup_id) > 0 and len(group_id) > 0: - sep_bool *= torch.max(score_mat[group_id][:,nongroup_id]) < threshold - - return sep_bool - -def test_general_separability(model, x, groups, threshold=1e-2): - ''' - test function separability - - Args: - ----- - model : MultKAN, MLP or python function - x : 2D torch.float - inputs - mode : str - mode = 'add' or mode = 'mul' - score_th : float - threshold of score - res_th : float - threshold of residue - bias : float - bias (for multiplicative separability) - verbose : bool - - Returns: - -------- - bool - - Example - ------- - >>> from kan.hypothesis import * - >>> model = lambda x: x[:,[0]] ** 2 * (x[:,[1]]**2+x[:,[2]]**2)**2 - >>> x = torch.normal(0,1,size=(100,3)) - >>> print(test_general_separability(model, x, [[1],[0,2]])) # False - >>> print(test_general_separability(model, x, [[0],[1,2]])) # True - ''' - grad = batch_jacobian(model, x) - - gensep_bool = True - - n_groups = len(groups) - for i in range(n_groups): - for j in range(i+1,n_groups): - group_A = groups[i] - group_B = groups[j] - for member_A in group_A: - for 
member_B in group_B:
- def func(x):
- grad = batch_jacobian(model, x, create_graph=True)
- return grad[:,[member_B]]/grad[:,[member_A]]
- # test if func is multiplicative separable
- gensep_bool *= test_separability(func, x, groups, mode='mul', threshold=threshold)
- return gensep_bool
-
-
-def get_molecule(model, x, sym_th=1e-3, verbose=True):
- '''
- how variables are combined hierarchically
-
- Args:
- -----
- model : MultKAN, MLP or python function
- x : 2D torch.float
- inputs
- sym_th : float
- threshold of symmetry
- verbose : bool
-
- Returns:
- --------
- list
-
- Example
- -------
- >>> from kan.hypothesis import *
- >>> model = lambda x: ((x[:,[0]] ** 2 + x[:,[1]] ** 2) ** 2 + (x[:,[2]] ** 2 + x[:,[3]] ** 2) ** 2) ** 2 + ((x[:,[4]] ** 2 + x[:,[5]] ** 2) ** 2 + (x[:,[6]] ** 2 + x[:,[7]] ** 2) ** 2) ** 2
- >>> x = torch.normal(0,1,size=(100,8))
- >>> get_molecule(model, x, verbose=False)
- [[[0], [1], [2], [3], [4], [5], [6], [7]],
- [[0, 1], [2, 3], [4, 5], [6, 7]],
- [[0, 1, 2, 3], [4, 5, 6, 7]],
- [[0, 1, 2, 3, 4, 5, 6, 7]]]
- '''
- n = x.shape[1]
- atoms = [[i] for i in range(n)]
- molecules = []
- moleculess = [copy.deepcopy(atoms)]
- already_full = False
- n_layer = 0
- last_n_molecule = n
-
- while True:
-
-
- pointer = 0
- current_molecule = []
- remove_atoms = []
- n_atom = 0
-
- while len(atoms) > 0:
-
- # assemble molecule
- atom = atoms[pointer]
- if verbose:
- print(current_molecule)
- print(atom)
-
- if len(current_molecule) == 0:
- full = False
- current_molecule += atom
- remove_atoms.append(atom)
- n_atom += 1
- else:
- # try assemble the atom to the molecule
- if len(current_molecule+atom) == x.shape[1] and already_full == False and n_atom > 1 and n_layer > 0:
- full = True
- already_full = True
- else:
- full = False
- if test_symmetry(model, x, current_molecule+atom, dependence_th=sym_th):
- current_molecule += atom
- remove_atoms.append(atom)
- n_atom += 1
-
- pointer += 1
-
- if pointer == len(atoms) or full:
- molecules.append(current_molecule)
- if full:
- molecules.append(atom)
- remove_atoms.append(atom)
- # remove molecules from atoms
- for atom in remove_atoms:
- atoms.remove(atom)
- current_molecule = []
- remove_atoms = []
- pointer = 0
-
- # if not making progress, terminate
- if len(molecules) == last_n_molecule:
- def flatten(xss):
- return [x for xs in xss for x in xs]
- moleculess.append([flatten(molecules)])
- break
- else:
- moleculess.append(copy.deepcopy(molecules))
-
- last_n_molecule = len(molecules)
-
- if len(molecules) == 1:
- break
-
- atoms = molecules
- molecules = []
-
- n_layer += 1
-
- #print(n_layer, atoms)
-
-
- # sort
- depth = len(moleculess) - 1
-
- for l in list(range(depth,0,-1)):
-
- molecules_sorted = []
- molecules_l = moleculess[l]
- molecules_lm1 = moleculess[l-1]
-
-
- for molecule_l in molecules_l:
- start = 0
- for i in range(1,len(molecule_l)+1):
- if molecule_l[start:i] in molecules_lm1:
-
- molecules_sorted.append(molecule_l[start:i])
- start = i
-
- moleculess[l-1] = molecules_sorted
-
- return moleculess
-
-
-def get_tree_node(model, x, moleculess, sep_th=1e-2, skip_test=True):
- '''
- get tree nodes
-
- Args:
- -----
- model : MultKAN, MLP or python function
- x : 2D torch.float
- inputs
- sep_th : float
- threshold of separability
- skip_test : bool
- if True, don't test the property of each module (to save time)
-
- Returns:
- --------
- arities : list of numbers
- properties : list of strings
-
- Example
- -------
- >>> from kan.hypothesis import *
- >>> model = lambda x: ((x[:,[0]] ** 2 + x[:,[1]] ** 2) ** 2 + (x[:,[2]] ** 2 + x[:,[3]] ** 2) ** 2) ** 2 + ((x[:,[4]] ** 2 + x[:,[5]] ** 2) ** 2 + (x[:,[6]] ** 2 + x[:,[7]] ** 2) ** 2) ** 2
- >>> x = torch.normal(0,1,size=(100,8))
- >>> moleculess = get_molecule(model, x, verbose=False)
- >>> get_tree_node(model, x, moleculess, skip_test=False)
- '''
- arities = []
- properties = []
-
- depth = len(moleculess) - 1
-
- for l in range(depth):
- molecules_l = copy.deepcopy(moleculess[l])
- molecules_lp1 = copy.deepcopy(moleculess[l+1])
- arity_l = []
- property_l = []
-
- for molecule in molecules_lp1:
- start = 0
- arity = 0
- groups = []
- for i in range(1,len(molecule)+1):
- if molecule[start:i] in molecules_l:
- groups.append(molecule[start:i])
- start = i
- arity += 1
- arity_l.append(arity)
-
- if arity == 1:
- property = 'Id'
- else:
- property = ''
- # test property
- if skip_test:
- gensep_bool = False
- else:
- gensep_bool = test_general_separability(model, x, groups, threshold=sep_th)
-
- if gensep_bool:
- property = 'GS'
- if l == depth - 1:
- if skip_test:
- add_bool = False
- mul_bool = False
- else:
- add_bool = test_separability(model, x, groups, mode='add', threshold=sep_th)
- mul_bool = test_separability(model, x, groups, mode='mul', threshold=sep_th)
- if add_bool:
- property = 'Add'
- if mul_bool:
- property = 'Mul'
-
-
- property_l.append(property)
-
-
- arities.append(arity_l)
- properties.append(property_l)
-
- return arities, properties
-
-
-def plot_tree(model, x, in_var=None, style='tree', sym_th=1e-3, sep_th=1e-1, skip_sep_test=False, verbose=False):
- '''
- get tree graph
-
- Args:
- -----
- model : MultKAN, MLP or python function
- x : 2D torch.float
- inputs
- in_var : list of symbols
- input variables
- style : str
- 'tree' or 'box'
- sym_th : float
- threshold of symmetry
- sep_th : float
- threshold of separability
- skip_sep_test : bool
- if True, don't test the property of each module (to save time)
- verbose : bool
-
- Returns:
- --------
- a tree graph
-
- Example
- -------
- >>> from kan.hypothesis import *
- >>> model = lambda x: ((x[:,[0]] ** 2 + x[:,[1]] ** 2) ** 2 + (x[:,[2]] ** 2 + x[:,[3]] ** 2) ** 2) ** 2 + ((x[:,[4]] ** 2 + x[:,[5]] ** 2) ** 2 + (x[:,[6]] ** 2 + x[:,[7]] ** 2) ** 2) ** 2
- >>> x = torch.normal(0,1,size=(100,8))
- >>> plot_tree(model, x)
- '''
- moleculess = get_molecule(model, x, sym_th=sym_th, verbose=verbose)
- arities, properties = get_tree_node(model, x, moleculess, sep_th=sep_th, skip_test=skip_sep_test)
-
- n = x.shape[1]
- var = None
-
- in_vars = []
-
- if in_var == None:
- for ii in range(1, n + 1):
- exec(f"x{ii} = sympy.Symbol('x_{ii}')")
- exec(f"in_vars.append(x{ii})")
- elif type(var[0]) == Symbol:
- in_vars = var
- else:
- in_vars = [sympy.symbols(var_) for var_ in var]
-
-
- def flatten(xss):
- return [x for xs in xss for x in xs]
-
- def myrectangle(center_x, center_y, width_x, width_y):
- plt.plot([center_x - width_x/2, center_x + width_x/2], [center_y + width_y/2, center_y + width_y/2], color='k') # up
- plt.plot([center_x - width_x/2, center_x + width_x/2], [center_y - width_y/2, center_y - width_y/2], color='k') # down
- plt.plot([center_x - width_x/2, center_x - width_x/2], [center_y - width_y/2, center_y + width_y/2], color='k') # left
- plt.plot([center_x + width_x/2, center_x + width_x/2], [center_y - width_y/2, center_y + width_y/2], color='k') # left
-
- depth = len(moleculess)
-
- delta = 1/n
- a = 0.3
- b = 0.15
- y0 = 0.5
-
-
- # draw rectangles
- for l in range(depth-1):
- molecules = moleculess[l+1]
- n_molecule = len(molecules)
-
- centers = []
-
- acc_arity = 0
-
- for i in range(n_molecule):
- start_id = len(flatten(molecules[:i]))
- end_id = len(flatten(molecules[:i+1]))
-
- center_x = (start_id + (end_id - 1 - start_id)/2) * delta + delta/2
- center_y = (l+1/2)*y0
- width_x = (end_id - start_id - 1 + 2*a)*delta
- width_y = 2*b
-
- # add text (numbers) on rectangles
- if style == 'box':
- myrectangle(center_x, center_y, width_x, width_y)
- plt.text(center_x, center_y, properties[l][i], fontsize=15, horizontalalignment='center',
- verticalalignment='center')
- elif style == 'tree':
- # if 'GS', no rectangle, n=arity tilted lines
- # if 'Id', no rectangle, n=arity vertical lines
- # if 'Add' or 'Mul'. rectangle, "+" or "x"
- # if '', rectangle
- property = properties[l][i]
- if property == 'GS' or property == 'Add' or property == 'Mul':
- color = 'blue'
- arity = arities[l][i]
- for j in range(arity):
-
- if l == 0:
- # x = (start_id + j) * delta + delta/2, center_x
- # y = center_y - b, center_y + b
- plt.plot([(start_id + j) * delta + delta/2, center_x], [center_y - b, center_y + b], color=color)
- else:
- # x = last_centers[acc_arity:acc_arity+arity], center_x
- # y = center_y - b, center_y + b
- plt.plot([last_centers[acc_arity+j], center_x], [center_y - b, center_y + b], color=color)
-
- acc_arity += arity
-
- if property == 'Add' or property == 'Mul':
- if property == 'Add':
- symbol = '+'
- else:
- symbol = '*'
-
- plt.text(center_x, center_y + b, symbol, horizontalalignment='center',
- verticalalignment='center', color='red', fontsize=40)
- if property == 'Id':
- plt.plot([center_x, center_x], [center_y-width_y/2, center_y+width_y/2], color='black')
-
- if property == '':
- myrectangle(center_x, center_y, width_x, width_y)
-
-
-
- # connections to the next layer
- plt.plot([center_x, center_x], [center_y+width_y/2, center_y+y0-width_y/2], color='k')
- centers.append(center_x)
- last_centers = copy.deepcopy(centers)
-
- # connections from input variables to the first layer
- for i in range(n):
- x_ = (i + 1/2) * delta
- # connections to the next layer
- plt.plot([x_, x_], [0, y0/2-width_y/2], color='k')
- plt.text(x_, -0.05*(depth-1), f'${latex(in_vars[moleculess[0][i][0]])}$', fontsize=20, horizontalalignment='center')
- plt.xlim(0,1)
- #plt.ylim(0,1);
- plt.axis('off');
- plt.show()
-
-
-def test_symmetry_var(model, x, input_vars, symmetry_var):
- '''
- test symmetry
-
- Args:
- -----
- model : MultKAN, MLP or python function
- x : 2D torch.float
- inputs
- input_vars : list of sympy symbols
- symmetry_var : sympy expression
-
- Returns:
- --------
- cosine similarity
-
- Example
- -------
- >>> from kan.hypothesis import *
- >>> from sympy import *
- >>> model = lambda x: x[:,[0]] * (x[:,[1]] + x[:,[2]])
- >>> x = torch.normal(0,1,size=(100,8))
- >>> input_vars = a, b, c = symbols('a b c')
- >>> symmetry_var = b + c
- >>> test_symmetry_var(model, x, input_vars, symmetry_var);
- >>> symmetry_var = b * c
- >>> test_symmetry_var(model, x, input_vars, symmetry_var);
- '''
- orig_vars = input_vars
- sym_var = symmetry_var
-
- # gradients wrt to input (model)
- input_grad = batch_jacobian(model, x)
-
- # gradients wrt to input (symmetry var)
- func = lambdify(orig_vars, sym_var,'numpy') # returns a numpy-ready function
-
- func2 = lambda x: func(*[x[:,[i]] for i in range(len(orig_vars))])
- sym_grad = batch_jacobian(func2, x)
-
- # get id
- idx = []
- sym_symbols = list(sym_var.free_symbols)
- for sym_symbol in sym_symbols:
- for j in range(len(orig_vars)):
- if sym_symbol == orig_vars[j]:
- idx.append(j)
-
- input_grad_part = input_grad[:,idx]
- sym_grad_part = sym_grad[:,idx]
-
- cossim = torch.abs(torch.sum(input_grad_part * sym_grad_part, dim=1)/(torch.norm(input_grad_part, dim=1)*torch.norm(sym_grad_part, dim=1)))
-
- ratio = torch.sum(cossim > 0.9)/len(cossim)
-
- print(f'{100*ratio}% data have more than 0.9 cosine similarity')
- if ratio > 0.9:
- print('suggesting symmetry')
- else:
- print('not suggesting symmetry')
-
- return cossim
\ No newline at end of file
diff --git a/models/others/kolmogorov_arnold_networks/kan/pytorch/kan/spline.py b/models/others/kolmogorov_arnold_networks/kan/pytorch/kan/spline.py
deleted file mode 100644
index d9ce592df86ae4997a5c0960ab993b4db70ae3ba..0000000000000000000000000000000000000000
--- a/models/others/kolmogorov_arnold_networks/kan/pytorch/kan/spline.py
+++ /dev/null
@@ -1,149 +0,0 @@
-# Copyright (c) 2024, Shanghai Iluvatar CoreX Semiconductor Co., Ltd.
-# All Rights Reserved.
-import torch
-
-
-def B_batch(x, grid, k=0, extend=True, device='cpu'):
- '''
- evaludate x on B-spline bases
-
- Args:
- -----
- x : 2D torch.tensor
- inputs, shape (number of splines, number of samples)
- grid : 2D torch.tensor
- grids, shape (number of splines, number of grid points)
- k : int
- the piecewise polynomial order of splines.
- extend : bool
- If True, k points are extended on both ends. If False, no extension (zero boundary condition). Default: True
- device : str
- devicde
-
- Returns:
- --------
- spline values : 3D torch.tensor
- shape (batch, in_dim, G+k). G: the number of grid intervals, k: spline order.
-
- Example
- -------
- >>> from kan.spline import B_batch
- >>> x = torch.rand(100,2)
- >>> grid = torch.linspace(-1,1,steps=11)[None, :].expand(2, 11)
- >>> B_batch(x, grid, k=3).shape
- '''
-
- x = x.unsqueeze(dim=2)
- grid = grid.unsqueeze(dim=0)
-
- if k == 0:
- value = (x >= grid[:, :, :-1]) * (x < grid[:, :, 1:])
- else:
- B_km1 = B_batch(x[:,:,0], grid=grid[0], k=k - 1)
-
- value = (x - grid[:, :, :-(k + 1)]) / (grid[:, :, k:-1] - grid[:, :, :-(k + 1)]) * B_km1[:, :, :-1] + (
- grid[:, :, k + 1:] - x) / (grid[:, :, k + 1:] - grid[:, :, 1:(-k)]) * B_km1[:, :, 1:]
-
- # in case grid is degenerate
- value = torch.nan_to_num(value)
- return value
-
-
-
-def coef2curve(x_eval, grid, coef, k, device="cpu"):
- '''
- converting B-spline coefficients to B-spline curves. Evaluate x on B-spline curves (summing up B_batch results over B-spline basis).
-
- Args:
- -----
- x_eval : 2D torch.tensor
- shape (batch, in_dim)
- grid : 2D torch.tensor
- shape (in_dim, G+2k). G: the number of grid intervals; k: spline order.
- coef : 3D torch.tensor
- shape (in_dim, out_dim, G+k)
- k : int
- the piecewise polynomial order of splines.
- device : str
- devicde
-
- Returns:
- --------
- y_eval : 3D torch.tensor
- shape (batch, in_dim, out_dim)
-
- '''
-
- b_splines = B_batch(x_eval, grid, k=k)
- y_eval = torch.einsum('ijk,jlk->ijl', b_splines, coef.to(b_splines.device))
-
- return y_eval
-
-
-def curve2coef(x_eval, y_eval, grid, k):
- '''
- converting B-spline curves to B-spline coefficients using least squares.
-
- Args:
- -----
- x_eval : 2D torch.tensor
- shape (batch, in_dim)
- y_eval : 3D torch.tensor
- shape (batch, in_dim, out_dim)
- grid : 2D torch.tensor
- shape (in_dim, grid+2*k)
- k : int
- spline order
- lamb : float
- regularized least square lambda
-
- Returns:
- --------
- coef : 3D torch.tensor
- shape (in_dim, out_dim, G+k)
- '''
- #print('haha', x_eval.shape, y_eval.shape, grid.shape)
- batch = x_eval.shape[0]
- in_dim = x_eval.shape[1]
- out_dim = y_eval.shape[2]
- n_coef = grid.shape[1] - k - 1
- mat = B_batch(x_eval, grid, k)
- mat = mat.permute(1,0,2)[:,None,:,:].expand(in_dim, out_dim, batch, n_coef)
- #print('mat', mat.shape)
- y_eval = y_eval.permute(1,2,0).unsqueeze(dim=3)
- #print('y_eval', y_eval.shape)
- device = mat.device
-
- # coef = torch.linalg.lstsq(mat, y_eval, driver='gelsy' if device == 'cpu' else 'gels').solution[:,:,:,0]
-
- coef = torch.linalg.lstsq(mat.cpu(), y_eval.cpu(), driver='gelsy' if device == 'cpu' else 'gels').solution[:,:,:,0]
- coef = coef.to(device)
- # try:
- # coef = torch.linalg.lstsq(mat, y_eval).solution[:,:,:,0]
- # except:
- # print('lstsq failed')
-
- # manual psuedo-inverse
- '''lamb=1e-8
- XtX = torch.einsum('ijmn,ijnp->ijmp', mat.permute(0,1,3,2), mat)
- Xty = torch.einsum('ijmn,ijnp->ijmp', mat.permute(0,1,3,2), y_eval)
- n1, n2, n = XtX.shape[0], XtX.shape[1], XtX.shape[2]
- identity = torch.eye(n,n)[None, None, :, :].expand(n1, n2, n, n).to(device)
- A = XtX + lamb * identity
- B = Xty
- coef = (A.pinverse() @ B)[:,:,:,0]'''
-
- return coef
-
-
-def extend_grid(grid, k_extend=0):
- '''
- extend grid
- '''
- h = (grid[:, [-1]] - grid[:, [0]]) / (grid.shape[1] - 1)
-
- for i in range(k_extend):
- grid = torch.cat([grid[:, [0]] - h, grid], dim=1)
- grid = torch.cat([grid, grid[:, [-1]] + h], dim=1)
-
- return grid
\ No newline at end of file
diff --git a/models/others/kolmogorov_arnold_networks/kan/pytorch/kan/utils.py b/models/others/kolmogorov_arnold_networks/kan/pytorch/kan/utils.py
deleted file mode 100644
index abb4d558ba0b8bfd92356f4d41fd1cfcb5e7bf55..0000000000000000000000000000000000000000
--- a/models/others/kolmogorov_arnold_networks/kan/pytorch/kan/utils.py
+++ /dev/null
@@ -1,594 +0,0 @@
-import numpy as np
-import torch
-from sklearn.linear_model import LinearRegression
-import sympy
-import yaml
-from sympy.utilities.lambdify import lambdify
-import re
-
-# sigmoid = sympy.Function('sigmoid')
-# name: (torch implementation, sympy implementation)
-
-# singularity protection functions
-f_inv = lambda x, y_th: ((x_th := 1/y_th), y_th/x_th*x * (torch.abs(x) < x_th) + torch.nan_to_num(1/x) * (torch.abs(x) >= x_th))
-f_inv2 = lambda x, y_th: ((x_th := 1/y_th**(1/2)), y_th * (torch.abs(x) < x_th) + torch.nan_to_num(1/x**2) * (torch.abs(x) >= x_th))
-f_inv3 = lambda x, y_th: ((x_th := 1/y_th**(1/3)), y_th/x_th*x * (torch.abs(x) < x_th) + torch.nan_to_num(1/x**3) * (torch.abs(x) >= x_th))
-f_inv4 = lambda x, y_th: ((x_th := 1/y_th**(1/4)), y_th * (torch.abs(x) < x_th) + torch.nan_to_num(1/x**4) * (torch.abs(x) >= x_th))
-f_inv5 = lambda x, y_th: ((x_th := 1/y_th**(1/5)), y_th/x_th*x * (torch.abs(x) < x_th) + torch.nan_to_num(1/x**5) * (torch.abs(x) >= x_th))
-f_sqrt = lambda x, y_th: ((x_th := 1/y_th**2), x_th/y_th*x * (torch.abs(x) < x_th) + torch.nan_to_num(torch.sqrt(torch.abs(x))*torch.sign(x)) * (torch.abs(x) >= x_th))
-f_power1d5 = lambda x, y_th: torch.abs(x)**1.5
-f_invsqrt = lambda x, y_th: ((x_th := 1/y_th**2), y_th * (torch.abs(x) < x_th) + torch.nan_to_num(1/torch.sqrt(torch.abs(x))) * (torch.abs(x) >= x_th))
-f_log = lambda x, y_th: ((x_th := torch.e**(-y_th)), - y_th * (torch.abs(x) < x_th) + torch.nan_to_num(torch.log(torch.abs(x))) * (torch.abs(x) >= x_th))
-f_tan = lambda x, y_th: ((clip := x % torch.pi), (delta := torch.pi/2-torch.arctan(y_th)), - y_th/delta * (clip - torch.pi/2) * (torch.abs(clip - torch.pi/2) < delta) + torch.nan_to_num(torch.tan(clip)) * (torch.abs(clip - torch.pi/2) >= delta))
-f_arctanh = lambda x, y_th: ((delta := 1-torch.tanh(y_th) + 1e-4), y_th * torch.sign(x) * (torch.abs(x) > 1 - delta) + torch.nan_to_num(torch.arctanh(x)) * (torch.abs(x) <= 1 - delta))
-f_arcsin = lambda x, y_th: ((), torch.pi/2 * torch.sign(x) * (torch.abs(x) > 1) + torch.nan_to_num(torch.arcsin(x)) * (torch.abs(x) <= 1))
-f_arccos = lambda x, y_th: ((), torch.pi/2 * (1-torch.sign(x)) * (torch.abs(x) > 1) + torch.nan_to_num(torch.arccos(x)) * (torch.abs(x) <= 1))
-f_exp = lambda x, y_th: ((x_th := torch.log(y_th)), y_th * (x > x_th) + torch.exp(x) * (x <= x_th))
-
-SYMBOLIC_LIB = {'x': (lambda x: x, lambda x: x, 1, lambda x, y_th: ((), x)),
- 'x^2': (lambda x: x**2, lambda x: x**2, 2, lambda x, y_th: ((), x**2)),
- 'x^3': (lambda x: x**3, lambda x: x**3, 3, lambda x, y_th: ((), x**3)),
- 'x^4': (lambda x: x**4, lambda x: x**4, 3, lambda x, y_th: ((), x**4)),
- 'x^5': (lambda x: x**5, lambda x: x**5, 3, lambda x, y_th: ((), x**5)),
- '1/x': (lambda x: 1/x, lambda x: 1/x, 2, f_inv),
- '1/x^2': (lambda x: 1/x**2, lambda x: 1/x**2, 2, f_inv2),
- '1/x^3': (lambda x: 1/x**3, lambda x: 1/x**3, 3, f_inv3),
- '1/x^4': (lambda x: 1/x**4, lambda x: 1/x**4, 4, f_inv4),
- '1/x^5': (lambda x: 1/x**5, lambda x: 1/x**5, 5, f_inv5),
- 'sqrt': (lambda x: torch.sqrt(x), lambda x: sympy.sqrt(x), 2, f_sqrt),
- 'x^0.5': (lambda x: torch.sqrt(x), lambda x: sympy.sqrt(x), 2, f_sqrt),
- 'x^1.5': (lambda x: torch.sqrt(x)**3, lambda x: sympy.sqrt(x)**3, 4, f_power1d5),
- '1/sqrt(x)': (lambda x: 1/torch.sqrt(x), lambda x: 1/sympy.sqrt(x), 2, f_invsqrt),
- '1/x^0.5': (lambda x: 1/torch.sqrt(x), lambda x: 1/sympy.sqrt(x), 2, f_invsqrt),
- 'exp': (lambda x: torch.exp(x), lambda x: sympy.exp(x), 2, f_exp),
- 'log': (lambda x: torch.log(x), lambda x: sympy.log(x), 2, f_log),
- 'abs': (lambda x: torch.abs(x), lambda x: sympy.Abs(x), 3, lambda x, y_th: ((), torch.abs(x))),
- 'sin': (lambda x: torch.sin(x), lambda x: sympy.sin(x), 2, lambda x, y_th: ((), torch.sin(x))),
- 'cos': (lambda x: torch.cos(x), lambda x: sympy.cos(x), 2, lambda x, y_th: ((), torch.cos(x))),
- 'tan': (lambda x: torch.tan(x), lambda x: sympy.tan(x), 3, f_tan),
- 'tanh': (lambda x: torch.tanh(x), lambda x: sympy.tanh(x), 3, lambda x, y_th: ((), torch.tanh(x))),
- 'sgn': (lambda x: torch.sign(x), lambda x: sympy.sign(x), 3, lambda x, y_th: ((), torch.sign(x))),
- 'arcsin': (lambda x: torch.arcsin(x), lambda x: sympy.asin(x), 4, f_arcsin),
- 'arccos': (lambda x: torch.arccos(x), lambda x: sympy.acos(x), 4, f_arccos),
- 'arctan': (lambda x: torch.arctan(x), lambda x: sympy.atan(x), 4, lambda x, y_th: ((), torch.arctan(x))),
- 'arctanh': (lambda x: torch.arctanh(x), lambda x: sympy.atanh(x), 4, f_arctanh),
- '0': (lambda x: x*0, lambda x: x*0, 0, lambda x, y_th: ((), x*0)),
- 'gaussian': (lambda x: torch.exp(-x**2), lambda x: sympy.exp(-x**2), 3, lambda x, y_th: ((), torch.exp(-x**2))),
- #'cosh': (lambda x: torch.cosh(x), lambda x: sympy.cosh(x), 5),
- #'sigmoid': (lambda x: torch.sigmoid(x), sympy.Function('sigmoid'), 4),
- #'relu': (lambda x: torch.relu(x), relu),
-}
-
-def create_dataset(f,
- n_var=2,
- f_mode = 'col',
- ranges = [-1,1],
- train_num=1000,
- test_num=1000,
- normalize_input=False,
- normalize_label=False,
- device='cpu',
- seed=0):
- '''
- create dataset
-
- Args:
- -----
- f : function
- the symbolic formula used to create the synthetic dataset
- ranges : list or np.array; shape (2,) or (n_var, 2)
- the range of input variables. Default: [-1,1].
- train_num : int
- the number of training samples. Default: 1000.
- test_num : int
- the number of test samples. Default: 1000.
- normalize_input : bool
- If True, apply normalization to inputs. Default: False.
- normalize_label : bool
- If True, apply normalization to labels. Default: False.
- device : str
- device. Default: 'cpu'.
- seed : int
- random seed. Default: 0.
-
- Returns:
- --------
- dataset : dic
- Train/test inputs/labels are dataset['train_input'], dataset['train_label'],
- dataset['test_input'], dataset['test_label']
-
- Example
- -------
- >>> f = lambda x: torch.exp(torch.sin(torch.pi*x[:,[0]]) + x[:,[1]]**2)
- >>> dataset = create_dataset(f, n_var=2, train_num=100)
- >>> dataset['train_input'].shape
- torch.Size([100, 2])
- '''
-
- np.random.seed(seed)
- torch.manual_seed(seed)
-
- if len(np.array(ranges).shape) == 1:
- ranges = np.array(ranges * n_var).reshape(n_var,2)
- else:
- ranges = np.array(ranges)
-
-
- train_input = torch.zeros(train_num, n_var)
- test_input = torch.zeros(test_num, n_var)
- for i in range(n_var):
- train_input[:,i] = torch.rand(train_num,)*(ranges[i,1]-ranges[i,0])+ranges[i,0]
- test_input[:,i] = torch.rand(test_num,)*(ranges[i,1]-ranges[i,0])+ranges[i,0]
-
- if f_mode == 'col':
- train_label = f(train_input)
- test_label = f(test_input)
- elif f_mode == 'row':
- train_label = f(train_input.T)
- test_label = f(test_input.T)
- else:
- print(f'f_mode {f_mode} not recognized')
-
- # if has only 1 dimension
- if len(train_label.shape) == 1:
- train_label = train_label.unsqueeze(dim=1)
- test_label = test_label.unsqueeze(dim=1)
-
- def normalize(data, mean, std):
- return (data-mean)/std
-
- if normalize_input == True:
- mean_input = torch.mean(train_input, dim=0, keepdim=True)
- std_input = torch.std(train_input, dim=0, keepdim=True)
- train_input = normalize(train_input, mean_input, std_input)
- test_input = normalize(test_input, mean_input, std_input)
-
- if normalize_label == True:
- mean_label = torch.mean(train_label, dim=0, keepdim=True)
- std_label = torch.std(train_label, dim=0, keepdim=True)
- train_label = normalize(train_label, mean_label, std_label)
- test_label = normalize(test_label, mean_label, std_label)
-
- dataset = {}
- dataset['train_input'] = train_input.to(device)
- dataset['test_input'] = test_input.to(device)
-
- dataset['train_label'] = train_label.to(device)
- dataset['test_label'] = test_label.to(device)
-
- return dataset
-
-
-
-def fit_params(x, y, fun, a_range=(-10,10), b_range=(-10,10), grid_number=101, iteration=3, verbose=True, device='cpu'):
- '''
- fit a, b, c, d such that
-
- .. math::
- |y-(cf(ax+b)+d)|^2
-
- is minimized. Both x and y are 1D array. Sweep a and b, find the best fitted model.
-
- Args:
- -----
- x : 1D array
- x values
- y : 1D array
- y values
- fun : function
- symbolic function
- a_range : tuple
- sweeping range of a
- b_range : tuple
- sweeping range of b
- grid_num : int
- number of steps along a and b
- iteration : int
- number of zooming in
- verbose : bool
- print extra information if True
- device : str
- device
-
- Returns:
- --------
- a_best : float
- best fitted a
- b_best : float
- best fitted b
- c_best : float
- best fitted c
- d_best : float
- best fitted d
- r2_best : float
- best r2 (coefficient of determination)
-
- Example
- -------
- >>> num = 100
- >>> x = torch.linspace(-1,1,steps=num)
- >>> noises = torch.normal(0,1,(num,)) * 0.02
- >>> y = 5.0*torch.sin(3.0*x + 2.0) + 0.7 + noises
- >>> fit_params(x, y, torch.sin)
- r2 is 0.9999727010726929
- (tensor([2.9982, 1.9996, 5.0053, 0.7011]), tensor(1.0000))
- '''
- # fit a, b, c, d such that y=c*fun(a*x+b)+d; both x and y are 1D array.
- # sweep a and b, choose the best fitted model
- for _ in range(iteration):
- a_ = torch.linspace(a_range[0], a_range[1], steps=grid_number, device=device)
- b_ = torch.linspace(b_range[0], b_range[1], steps=grid_number, device=device)
- a_grid, b_grid = torch.meshgrid(a_, b_, indexing='ij')
- post_fun = fun(a_grid[None,:,:] * x[:,None,None] + b_grid[None,:,:])
- x_mean = torch.mean(post_fun, dim=[0], keepdim=True)
- y_mean = torch.mean(y, dim=[0], keepdim=True)
- numerator = torch.sum((post_fun - x_mean)*(y-y_mean)[:,None,None], dim=0)**2
- denominator = torch.sum((post_fun - x_mean)**2, dim=0)*torch.sum((y - y_mean)[:,None,None]**2, dim=0)
- r2 = numerator/(denominator+1e-4)
- r2 = torch.nan_to_num(r2)
-
-
- best_id = torch.argmax(r2)
- a_id, b_id = torch.div(best_id, grid_number, rounding_mode='floor'), best_id % grid_number
-
-
- if a_id == 0 or a_id == grid_number - 1 or b_id == 0 or b_id == grid_number - 1:
- if _ == 0 and verbose==True:
- print('Best value at boundary.')
- if a_id == 0:
- a_range = [a_[0], a_[1]]
- if a_id == grid_number - 1:
- a_range = [a_[-2], a_[-1]]
- if b_id == 0:
- b_range = [b_[0], b_[1]]
- if b_id == grid_number - 1:
- b_range = [b_[-2], b_[-1]]
-
- else:
- a_range = [a_[a_id-1], a_[a_id+1]]
- b_range = [b_[b_id-1], b_[b_id+1]]
-
- a_best = a_[a_id]
- b_best = b_[b_id]
- post_fun = fun(a_best * x + b_best)
- r2_best = r2[a_id, b_id]
-
- if verbose == True:
- print(f"r2 is {r2_best}")
- if r2_best < 0.9:
- print(f'r2 is not very high, please double check if you are choosing the correct symbolic function.')
-
- post_fun = torch.nan_to_num(post_fun)
- reg = LinearRegression().fit(post_fun[:,None].detach().cpu().numpy(), y.detach().cpu().numpy())
- c_best = torch.from_numpy(reg.coef_)[0].to(device)
- d_best = torch.from_numpy(np.array(reg.intercept_)).to(device)
- return torch.stack([a_best, b_best, c_best, d_best]), r2_best
-
-
-def sparse_mask(in_dim, out_dim):
- '''
- get sparse mask
- '''
- in_coord = torch.arange(in_dim) * 1/in_dim + 1/(2*in_dim)
- out_coord = torch.arange(out_dim) * 1/out_dim + 1/(2*out_dim)
-
- dist_mat = torch.abs(out_coord[:,None] - in_coord[None,:])
- in_nearest = torch.argmin(dist_mat, dim=0)
- in_connection = torch.stack([torch.arange(in_dim), in_nearest]).permute(1,0)
- out_nearest = torch.argmin(dist_mat, dim=1)
- out_connection = torch.stack([out_nearest, torch.arange(out_dim)]).permute(1,0)
- all_connection = torch.cat([in_connection, out_connection], dim=0)
- mask = torch.zeros(in_dim, out_dim)
- mask[all_connection[:,0], all_connection[:,1]] = 1.
-
- return mask
-
-
-def add_symbolic(name, fun, c=1, fun_singularity=None):
- '''
- add a symbolic function to library
-
- Args:
- -----
- name : str
- name of the function
- fun : fun
- torch function or lambda function
-
- Returns:
- --------
- None
-
- Example
- -------
- >>> print(SYMBOLIC_LIB['Bessel'])
- KeyError: 'Bessel'
- >>> add_symbolic('Bessel', torch.special.bessel_j0)
- >>> print(SYMBOLIC_LIB['Bessel'])
- (, Bessel)
- '''
- exec(f"globals()['{name}'] = sympy.Function('{name}')")
- if fun_singularity==None:
- fun_singularity = fun
- SYMBOLIC_LIB[name] = (fun, globals()[name], c, fun_singularity)
-
-
-def ex_round(ex1, n_digit):
- '''
- rounding the numbers in an expression to certain floating points
-
- Args:
- -----
- ex1 : sympy expression
- n_digit : int
-
- Returns:
- --------
- ex2 : sympy expression
-
- Example
- -------
- >>> from kan.utils import *
- >>> from sympy import *
- >>> input_vars = a, b = symbols('a b')
- >>> expression = 3.14534242 * exp(sin(pi*a) + b**2) - 2.32345402
- >>> ex_round(expression, 2)
- '''
- ex2 = ex1
- for a in sympy.preorder_traversal(ex1):
- if isinstance(a, sympy.Float):
- ex2 = ex2.subs(a, round(a, n_digit))
- return ex2
-
-
-def augment_input(orig_vars, aux_vars, x):
- '''
- augment inputs
-
- Args:
- -----
- orig_vars : list of sympy symbols
- aux_vars : list of auxiliary symbols
- x : inputs
-
- Returns:
- --------
- augmented inputs
-
- Example
- -------
- >>> from kan.utils import *
- >>> from sympy import *
- >>> orig_vars = a, b = symbols('a b')
- >>> aux_vars = [a + b, a * b]
- >>> x = torch.rand(100, 2)
- >>> augment_input(orig_vars, aux_vars, x).shape
- '''
- # if x is a tensor
- if isinstance(x, torch.Tensor):
-
- aux_values = torch.tensor([]).to(x.device)
-
- for aux_var in aux_vars:
- func = lambdify(orig_vars, aux_var,'numpy') # returns a numpy-ready function
- aux_value = torch.from_numpy(func(*[x[:,[i]].numpy() for i in range(len(orig_vars))]))
- aux_values = torch.cat([aux_values, aux_value], dim=1)
-
- x = torch.cat([aux_values, x], dim=1)
-
- # if x is a dataset
- elif isinstance(x, dict):
- x['train_input'] = augment_input(orig_vars, aux_vars, x['train_input'])
- x['test_input'] = augment_input(orig_vars, aux_vars, x['test_input'])
-
- return x
-
-
-def batch_jacobian(func, x, create_graph=False, mode='scalar'):
- '''
- jacobian
-
- Args:
- -----
- func : function or model
- x : inputs
- create_graph : bool
-
- Returns:
- --------
- jacobian
-
- Example
- -------
- >>> from kan.utils import batch_jacobian
- >>> x = torch.normal(0,1,size=(100,2))
- >>> model = lambda x: x[:,[0]] + x[:,[1]]
- >>> batch_jacobian(model, x)
- '''
- # x in shape (Batch, Length)
- def _func_sum(x):
- return func(x).sum(dim=0)
- if mode == 'scalar':
- return torch.autograd.functional.jacobian(_func_sum, x, create_graph=create_graph)[0]
- elif mode == 'vector':
- return torch.autograd.functional.jacobian(_func_sum, x, create_graph=create_graph).permute(1,0,2)
-
-def batch_hessian(model, x, create_graph=False):
- '''
- hessian
-
- Args:
- -----
- func : function or model
- x : inputs
- create_graph : bool
-
- Returns:
- --------
- jacobian
-
- Example
- -------
- >>> from kan.utils import batch_hessian
- >>> x = torch.normal(0,1,size=(100,2))
- >>> model = lambda x: x[:,[0]]**2 + x[:,[1]]**2
- >>> batch_hessian(model, x)
- '''
- # x in shape (Batch, Length)
- jac = lambda x: batch_jacobian(model, x, create_graph=True)
- def _jac_sum(x):
- return jac(x).sum(dim=0)
- return torch.autograd.functional.jacobian(_jac_sum, x, create_graph=create_graph).permute(1,0,2)
-
-
-def create_dataset_from_data(inputs, labels, train_ratio=0.8, device='cpu'):
- '''
- create dataset from data
-
- Args:
- -----
- inputs : 2D torch.float
- labels : 2D torch.float
- train_ratio : float
- the ratio of training fraction
- device : str
-
- Returns:
- --------
- dataset (dictionary)
-
- Example
- -------
- >>> from kan.utils import create_dataset_from_data
- >>> x = torch.normal(0,1,size=(100,2))
- >>> y = torch.normal(0,1,size=(100,1))
- >>> dataset = create_dataset_from_data(x, y)
- >>> dataset['train_input'].shape
- '''
- num = inputs.shape[0]
- train_id = np.random.choice(num, int(num*train_ratio), replace=False)
- test_id = list(set(np.arange(num)) - set(train_id))
- dataset = {}
- dataset['train_input'] = inputs[train_id].detach().to(device)
- dataset['test_input'] = inputs[test_id].detach().to(device)
- dataset['train_label'] = labels[train_id].detach().to(device)
- dataset['test_label'] = labels[test_id].detach().to(device)
-
- return dataset
-
-
-def get_derivative(model, inputs, labels, derivative='hessian', loss_mode='pred', reg_metric='w', lamb=0., lamb_l1=1., lamb_entropy=0.):
- '''
- compute the jacobian/hessian of loss wrt to model parameters
-
- Args:
- -----
- inputs : 2D torch.float
- labels : 2D torch.float
- derivative : str
- 'jacobian' or 'hessian'
- device : str
-
- Returns:
- --------
- jacobian or hessian
- '''
- def get_mapping(model):
-
- mapping = {}
- name = 'model1'
-
- keys = list(model.state_dict().keys())
- for key in keys:
-
- y = re.findall(".[0-9]+", key)
- if len(y) > 0:
- y = y[0][1:]
- x = re.split(".[0-9]+", key)
- mapping[key] = name + '.' + x[0] + '[' + y + ']' + x[1]
-
-
- y = re.findall("_[0-9]+", key)
- if len(y) > 0:
- y = y[0][1:]
- x = re.split(".[0-9]+", key)
- mapping[key] = name + '.' + x[0] + '[' + y + ']'
-
- return mapping
-
-
- #model1 = copy.deepcopy(model)
- model1 = model.copy()
- mapping = get_mapping(model)
-
- # collect keys and shapes
- keys = list(model.state_dict().keys())
- shapes = []
-
- for params in model.parameters():
- shapes.append(params.shape)
-
-
- # turn a flattened vector to model params
- def param2statedict(p, keys, shapes):
-
- new_state_dict = {}
-
- start = 0
- n_group = len(keys)
- for i in range(n_group):
- shape = shapes[i]
- n_params = torch.prod(torch.tensor(shape))
- new_state_dict[keys[i]] = p[start:start+n_params].reshape(shape)
- start += n_params
-
- return new_state_dict
-
- def differentiable_load_state_dict(mapping, state_dict, model1):
-
- for key in keys:
- if mapping[key][-1] != ']':
- exec(f"del {mapping[key]}")
- exec(f"{mapping[key]} = state_dict[key]")
-
-
- # input: p, output: output
- def get_param2loss_fun(inputs, labels):
-
- def param2loss_fun(p):
-
- p = p[0]
- state_dict = param2statedict(p, keys, shapes)
- # this step is non-differentiable
- #model.load_state_dict(state_dict)
- differentiable_load_state_dict(mapping, state_dict, model1)
- if loss_mode == 'pred':
- pred_loss = torch.mean((model1(inputs) - labels)**2, dim=(0,1), keepdim=True)
- loss = pred_loss
- elif loss_mode == 'reg':
- reg_loss = model1.get_reg(reg_metric=reg_metric, lamb_l1=lamb_l1, lamb_entropy=lamb_entropy) * torch.ones(1,1)
- loss = reg_loss
- elif loss_mode == 'all':
- pred_loss = torch.mean((model1(inputs) - labels)**2, dim=(0,1), keepdim=True)
- reg_loss = model1.get_reg(reg_metric=reg_metric, lamb_l1=lamb_l1, lamb_entropy=lamb_entropy) * torch.ones(1,1)
- loss = pred_loss + lamb * reg_loss
- return loss
-
- return param2loss_fun
-
- fun = get_param2loss_fun(inputs, labels)
- p = model2param(model)[None,:]
- if derivative == 'hessian':
- result = batch_hessian(fun, p)
- elif derivative == 'jacobian':
- result = batch_jacobian(fun, p)
- return result
-
-def model2param(model):
- '''
- turn model parameters into a flattened vector
- '''
- p = torch.tensor([]).to(model.device)
- for params in model.parameters():
- p = torch.cat([p, params.reshape(-1,)], dim=0)
- return p
diff --git a/models/others/kolmogorov_arnold_networks/kan/pytorch/requirements.txt b/models/others/kolmogorov_arnold_networks/kan/pytorch/requirements.txt
index 4ccd831cb92f3d424f1784acc37a91298a524026..0c59fb0abdaf6fdd7835e2c3bdc93a2a47e3d817 100644
--- a/models/others/kolmogorov_arnold_networks/kan/pytorch/requirements.txt
+++ b/models/others/kolmogorov_arnold_networks/kan/pytorch/requirements.txt
@@ -1,4 +1,4 @@
-# matplotlib==3.6.2
+matplotlib==3.6.2
 numpy==1.24.4
 scikit_learn==1.1.3
 setuptools==65.5.0
diff --git a/models/others/kolmogorov_arnold_networks/kan/pytorch/run_train.sh b/models/others/kolmogorov_arnold_networks/kan/pytorch/run_train.sh
deleted file mode 100644
index 812a054e41ebdc9058ab75fe611621cd63a9321f..0000000000000000000000000000000000000000
--- a/models/others/kolmogorov_arnold_networks/kan/pytorch/run_train.sh
+++ /dev/null
@@ -1,16 +0,0 @@
-#!/bin/bash
-# Copyright (c) 2024, Shanghai Iluvatar CoreX Semiconductor Co., Ltd.
-# All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License"); you may
-# not use this file except in compliance with the License. You may obtain
-# a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-python3 ./train_kan.py --steps 100
\ No newline at end of file