-
Test-time regression: a unifying framework for designing sequence models with associative memory
Authors:
Ke Alexander Wang,
Jiaxin Shi,
Emily B. Fox
Abstract:
Sequence models lie at the heart of modern deep learning. However, rapid advancements have produced a diversity of seemingly unrelated architectures, such as Transformers and recurrent alternatives. In this paper, we introduce a unifying framework to understand and derive these sequence models, inspired by the empirical importance of associative recall, the capability to retrieve contextually relevant tokens. We formalize associative recall as a two-step process, memorization and retrieval, casting memorization as a regression problem. Layers that combine these two steps perform associative recall via "test-time regression" over their input tokens. Prominent layers, including linear attention, state-space models, fast-weight programmers, online learners, and softmax attention, arise as special cases defined by three design choices: the regression weights, the regressor function class, and the test-time optimization algorithm. Our approach clarifies how linear attention fails to capture inter-token correlations and offers a mathematical justification for the empirical effectiveness of query-key normalization in softmax attention. Further, it illuminates unexplored regions within the design space, which we use to derive novel higher-order generalizations of softmax attention. Beyond unification, our work bridges sequence modeling with classic regression methods, a field with extensive literature, paving the way for developing more powerful and theoretically principled architectures.
Submitted 1 May, 2025; v1 submitted 21 January, 2025;
originally announced January 2025.
-
Interpretable Mechanistic Representations for Meal-level Glycemic Control in the Wild
Authors:
Ke Alexander Wang,
Emily B. Fox
Abstract:
Diabetes encompasses a complex landscape of glycemic control that varies widely among individuals. However, current methods do not faithfully capture this variability at the meal level. On the one hand, expert-crafted features lack the flexibility of data-driven methods; on the other hand, learned representations tend to be uninterpretable, which hampers clinical adoption. In this paper, we propose a hybrid variational autoencoder to learn interpretable representations of CGM and meal data. Our method grounds the latent space to the inputs of a mechanistic differential equation, producing embeddings that reflect physiological quantities, such as insulin sensitivity, glucose effectiveness, and basal glucose levels. Moreover, we introduce a novel method to infer the glucose appearance rate, making the mechanistic model robust to unreliable meal logs. On a dataset of CGM and self-reported meals from individuals with type-2 diabetes and pre-diabetes, our unsupervised representation discovers a separation between individuals proportional to their disease severity. Our embeddings produce clusters that are up to 4x better than naive, expert, black-box, and pure mechanistic features. Our method provides a nuanced, yet interpretable, embedding space to compare glycemic control within and across individuals, directly learnable from in-the-wild data.
Submitted 6 December, 2023;
originally announced December 2023.
-
Sequence Modeling with Multiresolution Convolutional Memory
Authors:
Jiaxin Shi,
Ke Alexander Wang,
Emily B. Fox
Abstract:
Efficiently capturing the long-range patterns in sequential data sources salient to a given task -- such as classification and generative modeling -- poses a fundamental challenge. Popular approaches in the space trade off between the memory burden of brute-force enumeration and comparison, as in transformers, the computational burden of complicated sequential dependencies, as in recurrent neural networks, or the parameter burden of convolutional networks with many or large filters. We instead take inspiration from wavelet-based multiresolution analysis to define a new building block for sequence modeling, which we call a MultiresLayer. The key component of our model is the multiresolution convolution, capturing multiscale trends in the input sequence. Our MultiresConv can be implemented with shared filters across a dilated causal convolution tree. Thus it garners the computational advantages of convolutional networks and the principled theoretical motivation of wavelet decompositions. Our MultiresLayer is straightforward to implement, requires significantly fewer parameters, and maintains at most an $\mathcal{O}(N\log N)$ memory footprint for a length $N$ sequence. Yet, by stacking such layers, our model yields state-of-the-art performance on a number of sequence classification and autoregressive density estimation tasks using CIFAR-10, ListOps, and PTB-XL datasets.
Submitted 1 November, 2023; v1 submitted 2 May, 2023;
originally announced May 2023.
-
Is Importance Weighting Incompatible with Interpolating Classifiers?
Authors:
Ke Alexander Wang,
Niladri S. Chatterji,
Saminul Haque,
Tatsunori Hashimoto
Abstract:
Importance weighting is a classic technique to handle distribution shifts. However, prior work has presented strong empirical and theoretical evidence demonstrating that importance weights can have little to no effect on overparameterized neural networks. Is importance weighting truly incompatible with the training of overparameterized neural networks? Our paper answers this in the negative. We show that importance weighting fails not because of the overparameterization, but instead, as a result of using exponentially-tailed losses like the logistic or cross-entropy loss. As a remedy, we show that polynomially-tailed losses restore the effects of importance reweighting in correcting distribution shift in overparameterized models. We characterize the behavior of gradient descent on importance weighted polynomially-tailed losses with overparameterized linear models, and theoretically demonstrate the advantage of using polynomially-tailed losses in a label shift setting. Surprisingly, our theory shows that using weights that are obtained by exponentiating the classical unbiased importance weights can improve performance. Finally, we demonstrate the practical value of our analysis with neural network experiments on a subpopulation shift and a label shift dataset. When reweighted, our loss function can outperform reweighted cross-entropy by as much as 9% in test accuracy. Our loss function also gives test accuracies comparable to, or even exceeding, well-tuned state-of-the-art methods for correcting distribution shifts.
Submitted 4 March, 2022; v1 submitted 24 December, 2021;
originally announced December 2021.
-
SKIing on Simplices: Kernel Interpolation on the Permutohedral Lattice for Scalable Gaussian Processes
Authors:
Sanyam Kapoor,
Marc Finzi,
Ke Alexander Wang,
Andrew Gordon Wilson
Abstract:
State-of-the-art methods for scalable Gaussian processes use iterative algorithms, requiring fast matrix-vector multiplies (MVMs) with the covariance kernel. The Structured Kernel Interpolation (SKI) framework accelerates these MVMs by performing efficient MVMs on a grid and interpolating back to the original space. In this work, we develop a connection between SKI and the permutohedral lattice used for high-dimensional fast bilateral filtering. Using a sparse simplicial grid instead of a dense rectangular one, we can perform GP inference exponentially faster in the dimension than SKI. Our approach, Simplex-GP, enables scaling SKI to high dimensions, while maintaining strong predictive performance. We additionally provide a CUDA implementation of Simplex-GP, which enables significant GPU acceleration of MVM-based inference.
Submitted 12 June, 2021;
originally announced June 2021.
-
Bayesian Algorithm Execution: Estimating Computable Properties of Black-box Functions Using Mutual Information
Authors:
Willie Neiswanger,
Ke Alexander Wang,
Stefano Ermon
Abstract:
In many real-world problems, we want to infer some property of an expensive black-box function $f$, given a budget of $T$ function evaluations. One example is budget-constrained global optimization of $f$, for which Bayesian optimization is a popular method. Other properties of interest include local optima, level sets, integrals, or graph-structured information induced by $f$. Often, we can find an algorithm $\mathcal{A}$ to compute the desired property, but it may require far more than $T$ queries to execute. Given such an $\mathcal{A}$, and a prior distribution over $f$, we refer to the problem of inferring the output of $\mathcal{A}$ using $T$ evaluations as Bayesian Algorithm Execution (BAX). To tackle this problem, we present a procedure, InfoBAX, that sequentially chooses queries that maximize mutual information with respect to the algorithm's output. Applying this to Dijkstra's algorithm, for instance, we infer shortest paths in synthetic and real-world graphs with black-box edge costs. Using evolution strategies, we obtain variants of Bayesian optimization that target local, rather than global, optima. On these problems, InfoBAX uses up to 500 times fewer queries to $f$ than required by the original algorithm. Our method is closely connected to other Bayesian optimal experimental design procedures such as entropy search methods and optimal sensor placement using Gaussian processes.
Submitted 6 July, 2021; v1 submitted 19 April, 2021;
originally announced April 2021.
-
Simplifying Hamiltonian and Lagrangian Neural Networks via Explicit Constraints
Authors:
Marc Finzi,
Ke Alexander Wang,
Andrew Gordon Wilson
Abstract:
Reasoning about the physical world requires models that are endowed with the right inductive biases to learn the underlying dynamics. Recent works improve generalization for predicting trajectories by learning the Hamiltonian or Lagrangian of a system rather than the differential equations directly. While these methods encode the constraints of the systems using generalized coordinates, we show that embedding the system into Cartesian coordinates and enforcing the constraints explicitly with Lagrange multipliers dramatically simplifies the learning problem. We introduce a series of challenging chaotic and extended-body systems, including systems with N-pendulums, spring coupling, magnetic fields, rigid rotors, and gyroscopes, to push the limits of current approaches. Our experiments show that Cartesian coordinates with explicit constraints lead to a 100x improvement in accuracy and data efficiency.
Submitted 26 October, 2020;
originally announced October 2020.
-
$DC^2$: A Divide-and-conquer Algorithm for Large-scale Kernel Learning with Application to Clustering
Authors:
Ke Alexander Wang,
Xinran Bian,
Pan Liu,
Donghui Yan
Abstract:
Divide-and-conquer is a general strategy to deal with large-scale problems. It is typically applied to generate ensemble instances, which potentially limits the problem size it can handle. Additionally, the data are often divided by random sampling, which may be suboptimal. To address these concerns, we propose the $DC^2$ algorithm. Instead of ensemble instances, we produce structure-preserving signature pieces to be assembled and conquered. $DC^2$ achieves the efficiency of sampling-based large-scale kernel methods while enabling parallel multicore or clustered computation. The data partition and subsequent compression are unified by recursive random projections. Empirically, dividing the data by random projections induces smaller mean squared approximation errors than conventional random sampling. The power of $DC^2$ is demonstrated by our clustering algorithm $rpfCluster^+$, which is as accurate as some of the fastest approximate spectral clustering algorithms while maintaining a running time close to that of K-means clustering. Analysis of $DC^2$ when applied to spectral clustering shows that the loss in clustering accuracy due to data division and reduction is upper bounded by the data approximation error, which would vanish with recursive random projections. Due to its easy implementation and flexibility, we expect $DC^2$ to be applicable to general large-scale learning problems.
Submitted 15 November, 2019;
originally announced November 2019.
-
Exact Gaussian Processes on a Million Data Points
Authors:
Ke Alexander Wang,
Geoff Pleiss,
Jacob R. Gardner,
Stephen Tyree,
Kilian Q. Weinberger,
Andrew Gordon Wilson
Abstract:
Gaussian processes (GPs) are flexible non-parametric models, with a capacity that grows with the available data. However, computational constraints with standard inference procedures have limited exact GPs to problems with fewer than about ten thousand training points, necessitating approximations for larger datasets. In this paper, we develop a scalable approach for exact GPs that leverages multi-GPU parallelization and methods like linear conjugate gradients, accessing the kernel matrix only through matrix multiplication. By partitioning and distributing kernel matrix multiplies, we demonstrate that an exact GP can be trained on over a million points, a task previously thought to be impossible with current computing hardware, in less than 2 hours. Moreover, our approach is generally applicable, without constraints to grid data or specific kernel classes. Enabled by this scalability, we perform the first-ever comparison of exact GPs against scalable GP approximations on datasets with $10^4 \!-\! 10^6$ data points, showing dramatic performance improvements.
Submitted 10 December, 2019; v1 submitted 19 March, 2019;
originally announced March 2019.