-
Mirror Mean-Field Langevin Dynamics
Authors:
Anming Gu,
Juno Kim
Abstract:
The mean-field Langevin dynamics (MFLD) minimizes an entropy-regularized nonlinear convex functional on the Wasserstein space over $\mathbb{R}^d$, and has gained attention recently as a model for the gradient descent dynamics of interacting particle systems such as infinite-width two-layer neural networks. However, many problems of interest have constrained domains, which are not solved by existing mean-field algorithms due to the global diffusion term. We study the optimization of probability measures constrained to a convex subset of $\mathbb{R}^d$ by proposing the \emph{mirror mean-field Langevin dynamics} (MMFLD), an extension of MFLD to the mirror Langevin framework. We obtain linear convergence guarantees for the continuous MMFLD via a uniform log-Sobolev inequality, and uniform-in-time propagation of chaos results for its time- and particle-discretized counterpart.
Submitted 5 May, 2025;
originally announced May 2025.
-
Partially Observed Trajectory Inference using Optimal Transport and a Dynamics Prior
Authors:
Anming Gu,
Edward Chien,
Kristjan Greenewald
Abstract:
Trajectory inference seeks to recover the temporal dynamics of a population from snapshots of its (uncoupled) temporal marginals, i.e. where observed particles are not tracked over time. Prior works addressed this challenging problem under a stochastic differential equation (SDE) model with a gradient-driven drift in the observed space, introducing a minimum entropy estimator relative to the Wiener measure and a practical grid-free mean-field Langevin (MFL) algorithm using Schrödinger bridges. Motivated by the success of observable state space models in the traditional paired trajectory inference problem (e.g. target tracking), we extend the above framework to a class of latent SDEs in the form of observable state space models. In this setting, we use partial observations to infer trajectories in the latent space under a specified dynamics model (e.g. the constant velocity/acceleration models from target tracking). We introduce the PO-MFL algorithm to solve this latent trajectory inference problem and provide theoretical guarantees for the partially observed setting. Experiments validate the robustness of our method and the exponential convergence of the MFL dynamics, and demonstrate significant outperformance over the latent-free baseline in key scenarios.
Submitted 26 February, 2025; v1 submitted 11 June, 2024;
originally announced June 2024.
-
LeanDojo: Theorem Proving with Retrieval-Augmented Language Models
Authors:
Kaiyu Yang,
Aidan M. Swope,
Alex Gu,
Rahul Chalamala,
Peiyang Song,
Shixing Yu,
Saad Godil,
Ryan Prenger,
Anima Anandkumar
Abstract:
Large language models (LLMs) have shown promise in proving formal theorems using proof assistants such as Lean. However, existing methods are difficult to reproduce or build on, due to private code, data, and large compute requirements. This has created substantial barriers to research on machine learning methods for theorem proving. This paper removes these barriers by introducing LeanDojo: an open-source Lean playground consisting of toolkits, data, models, and benchmarks. LeanDojo extracts data from Lean and enables interaction with the proof environment programmatically. It contains fine-grained annotations of premises in proofs, providing valuable data for premise selection: a key bottleneck in theorem proving. Using this data, we develop ReProver (Retrieval-Augmented Prover): an LLM-based prover augmented with retrieval for selecting premises from a vast math library. It is inexpensive and needs only one GPU week of training. Our retriever leverages LeanDojo's program analysis capability to identify accessible premises and hard negative examples, which makes retrieval much more effective. Furthermore, we construct a new benchmark consisting of 98,734 theorems and proofs extracted from Lean's math library. It features a challenging data split requiring the prover to generalize to theorems relying on novel premises that are never used in training. We use this benchmark for training and evaluation, and experimental results demonstrate the effectiveness of ReProver over non-retrieval baselines and GPT-4. We thus provide the first set of open-source LLM-based theorem provers without any proprietary datasets and release it under a permissive MIT license to facilitate further research.
Submitted 27 October, 2023; v1 submitted 27 June, 2023;
originally announced June 2023.
-
GIGA-Lens: Fast Bayesian Inference for Strong Gravitational Lens Modeling
Authors:
A. Gu,
X. Huang,
W. Sheu,
G. Aldering,
A. S. Bolton,
K. Boone,
A. Dey,
A. Filipp,
E. Jullo,
S. Perlmutter,
D. Rubin,
E. F. Schlafly,
D. J. Schlegel,
Y. Shu,
S. H. Suyu
Abstract:
We present GIGA-Lens: a gradient-informed, GPU-accelerated Bayesian framework for modeling strong gravitational lensing systems, implemented in TensorFlow and JAX. The three components, optimization using multi-start gradient descent, posterior covariance estimation with variational inference, and sampling via Hamiltonian Monte Carlo, all take advantage of gradient information through automatic differentiation and massive parallelization on graphics processing units (GPUs). We test our pipeline on a large set of simulated systems and demonstrate in detail its high level of performance. The average time to model a single system on four Nvidia A100 GPUs is 105 seconds. The robustness, speed, and scalability offered by this framework make it possible to model the large number of strong lenses found in current surveys and present a very promising prospect for the modeling of $\mathcal{O}(10^5)$ lensing systems expected to be discovered in the era of the Vera C. Rubin Observatory, Euclid, and the Nancy Grace Roman Space Telescope.
Submitted 15 February, 2022;
originally announced February 2022.
-
Kaleidoscope: An Efficient, Learnable Representation For All Structured Linear Maps
Authors:
Tri Dao,
Nimit S. Sohoni,
Albert Gu,
Matthew Eichhorn,
Amit Blonder,
Megan Leszczynski,
Atri Rudra,
Christopher Ré
Abstract:
Modern neural network architectures use structured linear transformations, such as low-rank matrices, sparse matrices, permutations, and the Fourier transform, to improve inference speed and reduce memory usage compared to general linear maps. However, choosing which of the myriad structured transformations to use (and its associated parameterization) is a laborious task that requires trading off speed, space, and accuracy. We consider a different approach: we introduce a family of matrices called kaleidoscope matrices (K-matrices) that provably capture any structured matrix with near-optimal space (parameter) and time (arithmetic operation) complexity. We empirically validate that K-matrices can be automatically learned within end-to-end pipelines to replace hand-crafted procedures, in order to improve model quality. For example, replacing channel shuffles in ShuffleNet improves classification accuracy on ImageNet by up to 5%. K-matrices can also simplify hand-engineered pipelines -- we replace filter bank feature computation in speech data preprocessing with a learnable kaleidoscope layer, resulting in only 0.4% loss in accuracy on the TIMIT speech recognition task. In addition, K-matrices can capture latent structure in models: for a challenging permuted image classification task, a K-matrix based representation of permutations is able to learn the right latent structure and improves accuracy of a downstream convolutional model by over 9%. We provide a practically efficient implementation of our approach, and use K-matrices in a Transformer network to attain 36% faster end-to-end inference speed on a language translation task.
Submitted 5 January, 2021; v1 submitted 29 December, 2020;
originally announced December 2020.
-
From Trees to Continuous Embeddings and Back: Hyperbolic Hierarchical Clustering
Authors:
Ines Chami,
Albert Gu,
Vaggos Chatziafratis,
Christopher Ré
Abstract:
Similarity-based Hierarchical Clustering (HC) is a classical unsupervised machine learning algorithm that has traditionally been solved with heuristic algorithms like Average-Linkage. Recently, Dasgupta reframed HC as a discrete optimization problem by introducing a global cost function measuring the quality of a given tree. In this work, we provide the first continuous relaxation of Dasgupta's discrete optimization problem with provable quality guarantees. The key idea of our method, HypHC, is showing a direct correspondence from discrete trees to continuous representations (via the hyperbolic embeddings of their leaf nodes) and back (via a decoding algorithm that maps leaf embeddings to a dendrogram), allowing us to search the space of discrete binary trees with continuous optimization. Building on analogies between trees and hyperbolic space, we derive a continuous analogue for the notion of lowest common ancestor, which leads to a continuous relaxation of Dasgupta's discrete objective. We can show that after decoding, the global minimizer of our continuous relaxation yields a discrete tree with a (1 + epsilon)-factor approximation for Dasgupta's optimal tree, where epsilon can be made arbitrarily small and controls optimization challenges. We experimentally evaluate HypHC on a variety of HC benchmarks and find that even approximate solutions found with gradient descent achieve better clustering quality than agglomerative heuristics or other gradient-based algorithms. Finally, we highlight the flexibility of HypHC using end-to-end training in a downstream classification task.
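For readers unfamiliar with the discrete objective mentioned in this abstract, the following is a minimal sketch of Dasgupta's cost in its standard form: each pair of leaves contributes its similarity multiplied by the number of leaves under the pair's lowest common ancestor. The nested-tuple tree encoding, the names `leaves`, `dasgupta_cost`, and `similarity`, and the toy weights are assumptions made for illustration; this is not the HypHC relaxation or the authors' code.

```python
def leaves(tree):
    """Return the leaf labels under a binary tree encoded as nested 2-tuples."""
    if not isinstance(tree, tuple):
        return [tree]
    left, right = tree
    return leaves(left) + leaves(right)


def dasgupta_cost(tree, weight):
    """Sum over leaf pairs of weight(i, j) * (number of leaves under their LCA)."""
    if not isinstance(tree, tuple):
        return 0.0  # a single leaf separates no pairs
    left, right = tree
    left_leaves, right_leaves = leaves(left), leaves(right)
    size = len(left_leaves) + len(right_leaves)
    # Pairs separated at this node have this node as lowest common ancestor.
    split = sum(weight(i, j) for i in left_leaves for j in right_leaves)
    return size * split + dasgupta_cost(left, weight) + dasgupta_cost(right, weight)


# Hypothetical similarities: a~b and c~d are similar, all other pairs are not.
strong = {frozenset("ab"), frozenset("cd")}


def similarity(i, j):
    return 1.0 if frozenset((i, j)) in strong else 0.1


good_tree = (("a", "b"), ("c", "d"))  # merges similar leaves first
bad_tree = (("a", "c"), ("b", "d"))   # merges dissimilar leaves first
print(dasgupta_cost(good_tree, similarity), dasgupta_cost(bad_tree, similarity))
```

The tree that merges similar leaves low in the hierarchy receives the lower cost, which is the behavior the objective is meant to reward.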
Submitted 1 October, 2020;
originally announced October 2020.
-
Using Undersampling with Ensemble Learning to Identify Factors Contributing to Preterm Birth
Authors:
Shi Dong,
Zlatan Feric,
Guangyu Li,
Chieh Wu,
April Z. Gu,
Jennifer Dy,
John Meeker,
Ingrid Y. Padilla,
Jose Cordero,
Carmen Velez Vega,
Zaira Rosario,
Akram Alshawabkeh,
David Kaeli
Abstract:
In this paper, we propose Ensemble Learning models to identify factors contributing to preterm birth. Our work leverages a rich dataset collected by an NIEHS P42 Center that is trying to identify the dominant factors responsible for the high rate of premature births in northern Puerto Rico. We investigate analytical models addressing two major challenges present in the dataset: 1) the significant amount of incomplete data in the dataset, and 2) class imbalance in the dataset. First, we leverage and compare two types of missing data imputation methods: 1) mean-based and 2) similarity-based, increasing the completeness of this dataset. Second, we propose a feature selection and evaluation model based on using undersampling with Ensemble Learning to address class imbalance present in the dataset. We leverage and compare multiple Ensemble Feature selection methods, including Complete Linear Aggregation (CLA), Weighted Mean Aggregation (WMA), Feature Occurrence Frequency (OFA), and Classification Accuracy Based Aggregation (CAA). To further address missing data present in each feature, we propose two novel methods: 1) Missing Data Rate and Accuracy Based Aggregation (MAA), and 2) Entropy and Accuracy Based Aggregation (EAA). Both proposed models balance the degree of data variance introduced by the missing data handling during the feature selection process while maintaining model performance. Our results show a 42% improvement in sensitivity versus fallout over previous state-of-the-art methods.
Submitted 23 September, 2020;
originally announced September 2020.
-
HiPPO: Recurrent Memory with Optimal Polynomial Projections
Authors:
Albert Gu,
Tri Dao,
Stefano Ermon,
Atri Rudra,
Christopher Re
Abstract:
A central problem in learning from sequential data is representing cumulative history in an incremental fashion as more data is processed. We introduce a general framework (HiPPO) for the online compression of continuous signals and discrete time series by projection onto polynomial bases. Given a measure that specifies the importance of each time step in the past, HiPPO produces an optimal solution to a natural online function approximation problem. As special cases, our framework yields a short derivation of the recent Legendre Memory Unit (LMU) from first principles, and generalizes the ubiquitous gating mechanism of recurrent neural networks such as GRUs. This formal framework yields a new memory update mechanism (HiPPO-LegS) that scales through time to remember all history, avoiding priors on the timescale. HiPPO-LegS enjoys the theoretical benefits of timescale robustness, fast updates, and bounded gradients. By incorporating the memory dynamics into recurrent neural networks, HiPPO RNNs can empirically capture complex temporal dependencies. On the benchmark permuted MNIST dataset, HiPPO-LegS sets a new state-of-the-art accuracy of 98.3%. Finally, on a novel trajectory classification task testing robustness to out-of-distribution timescales and missing data, HiPPO-LegS outperforms RNN and neural ODE baselines by 25-40% accuracy.
Submitted 22 October, 2020; v1 submitted 17 August, 2020;
originally announced August 2020.
-
Model Patching: Closing the Subgroup Performance Gap with Data Augmentation
Authors:
Karan Goel,
Albert Gu,
Yixuan Li,
Christopher Ré
Abstract:
Classifiers in machine learning are often brittle when deployed. Particularly concerning are models with inconsistent performance on specific subgroups of a class, e.g., exhibiting disparities in skin cancer classification in the presence or absence of a spurious bandage. To mitigate these performance differences, we introduce model patching, a two-stage framework for improving robustness that encourages the model to be invariant to subgroup differences, and to focus on class information shared by subgroups. Model patching first models subgroup features within a class and learns semantic transformations between them, and then trains a classifier with data augmentations that deliberately manipulate subgroup features. We instantiate model patching with CAMEL, which (1) uses a CycleGAN to learn the intra-class, inter-subgroup augmentations, and (2) balances subgroup performance using a theoretically-motivated subgroup consistency regularizer, accompanied by a new robust objective. We demonstrate CAMEL's effectiveness on 3 benchmark datasets, with reductions in robust error of up to 33% relative to the best baseline. Lastly, CAMEL successfully patches a model that fails due to spurious features on a real-world skin cancer dataset.
Submitted 15 August, 2020;
originally announced August 2020.
-
Learning Fast Algorithms for Linear Transforms Using Butterfly Factorizations
Authors:
Tri Dao,
Albert Gu,
Matthew Eichhorn,
Atri Rudra,
Christopher Ré
Abstract:
Fast linear transforms are ubiquitous in machine learning, including the discrete Fourier transform, discrete cosine transform, and other structured transformations such as convolutions. All of these transforms can be represented by dense matrix-vector multiplication, yet each has a specialized and highly efficient (subquadratic) algorithm. We ask to what extent hand-crafting these algorithms and implementations is necessary, what structural priors they encode, and how much knowledge is required to automatically learn a fast algorithm for a provided structured transform. Motivated by a characterization of fast matrix-vector multiplication as products of sparse matrices, we introduce a parameterization of divide-and-conquer methods that is capable of representing a large class of transforms. This generic formulation can automatically learn an efficient algorithm for many important transforms; for example, it recovers the $O(N \log N)$ Cooley-Tukey FFT algorithm to machine precision, for dimensions $N$ up to $1024$. Furthermore, our method can be incorporated as a lightweight replacement of generic matrices in machine learning pipelines to learn efficient and compressible transformations. On a standard task of compressing a single hidden-layer network, our method exceeds the classification accuracy of unconstrained matrices on CIFAR-10 by 3.9 points -- the first time a structured approach has done so -- with 4X faster inference speed and 40X fewer parameters.
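As background for the $O(N \log N)$ Cooley-Tukey algorithm named in this abstract, here is a minimal hand-written sketch of the radix-2 divide-and-conquer recursion, checked against the dense $O(N^2)$ definition of the discrete Fourier transform. It illustrates the classical algorithm only, not the learned butterfly parameterization the paper proposes; the function names `naive_dft` and `fft` and the sample input are assumptions made for illustration, and the input length is assumed to be a power of two.

```python
import cmath


def naive_dft(x):
    """Dense O(N^2) DFT: one matrix-vector product written out explicitly."""
    n = len(x)
    return [
        sum(x[t] * cmath.exp(-2j * cmath.pi * k * t / n) for t in range(n))
        for k in range(n)
    ]


def fft(x):
    """Radix-2 Cooley-Tukey FFT; assumes len(x) is a power of two."""
    n = len(x)
    if n == 1:
        return list(x)
    even = fft(x[0::2])  # DFT of the even-indexed samples
    odd = fft(x[1::2])   # DFT of the odd-indexed samples
    out = [0j] * n
    for k in range(n // 2):
        # Butterfly step: combine the two half-size results with a twiddle factor.
        twiddle = cmath.exp(-2j * cmath.pi * k / n) * odd[k]
        out[k] = even[k] + twiddle
        out[k + n // 2] = even[k] - twiddle
    return out


signal = [1, 2, 3, 4, 0, 0, 0, 0]
max_error = max(abs(a - b) for a, b in zip(fft(signal), naive_dft(signal)))
print(max_error)
```

Each level of the recursion touches every output once, and there are $\log_2 N$ levels, which is where the $O(N \log N)$ operation count comes from.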
Submitted 28 December, 2020; v1 submitted 14 March, 2019;
originally announced March 2019.
-
Learning Compressed Transforms with Low Displacement Rank
Authors:
Anna T. Thomas,
Albert Gu,
Tri Dao,
Atri Rudra,
Christopher Ré
Abstract:
The low displacement rank (LDR) framework for structured matrices represents a matrix through two displacement operators and a low-rank residual. Existing use of LDR matrices in deep learning has applied fixed displacement operators encoding forms of shift invariance akin to convolutions. We introduce a class of LDR matrices with more general displacement operators, and explicitly learn over both the operators and the low-rank component. This class generalizes several previous constructions while preserving compression and efficient computation. We prove bounds on the VC dimension of multi-layer neural networks with structured weight matrices and show empirically that our compact parameterization can reduce the sample complexity of learning. When replacing weight layers in fully-connected, convolutional, and recurrent neural networks for image classification and language modeling tasks, our new classes exceed the accuracy of existing compression approaches, and on some tasks also outperform general unstructured layers while using more than 20x fewer parameters.
Submitted 1 January, 2019; v1 submitted 4 October, 2018;
originally announced October 2018.
-
Representation Tradeoffs for Hyperbolic Embeddings
Authors:
Christopher De Sa,
Albert Gu,
Christopher Ré,
Frederic Sala
Abstract:
Hyperbolic embeddings offer excellent quality with few dimensions when embedding hierarchical data structures like synonym or type hierarchies. Given a tree, we give a combinatorial construction that embeds the tree in hyperbolic space with arbitrarily low distortion without using optimization. On WordNet, our combinatorial embedding obtains a mean-average-precision of 0.989 with only two dimensions, while Nickel et al.'s recent construction obtains 0.87 using 200 dimensions. We provide upper and lower bounds that allow us to characterize the precision-dimensionality tradeoff inherent in any hyperbolic embedding. To embed general metric spaces, we propose a hyperbolic generalization of multidimensional scaling (h-MDS). We show how to perform exact recovery of hyperbolic points from distances, provide a perturbation analysis, and give a recovery result that allows us to reduce dimensionality. The h-MDS approach offers consistently low distortion even with few dimensions across several datasets. Finally, we extract lessons from the algorithms and theory above to design a PyTorch-based implementation that can handle incomplete information and is scalable.
Submitted 24 April, 2018; v1 submitted 9 April, 2018;
originally announced April 2018.
-
A Kernel Theory of Modern Data Augmentation
Authors:
Tri Dao,
Albert Gu,
Alexander J. Ratner,
Virginia Smith,
Christopher De Sa,
Christopher Ré
Abstract:
Data augmentation, a technique in which a training set is expanded with class-preserving transformations, is ubiquitous in modern machine learning pipelines. In this paper, we seek to establish a theoretical framework for understanding data augmentation. We approach this from two directions: First, we provide a general model of augmentation as a Markov process, and show that kernels appear naturally with respect to this model, even when we do not employ kernel classification. Next, we analyze more directly the effect of augmentation on kernel classifiers, showing that data augmentation can be approximated by first-order feature averaging and second-order variance regularization components. These frameworks both serve to illustrate the ways in which data augmentation affects the downstream learning model, and the resulting analyses provide novel connections between prior work in invariant kernels, tangent propagation, and robust optimization. Finally, we provide several proof-of-concept applications showing that our theory can be useful for accelerating machine learning workflows, such as reducing the amount of computation needed to train using augmented data, and predicting the utility of a transformation prior to training.
Submitted 20 March, 2019; v1 submitted 16 March, 2018;
originally announced March 2018.
-
Sector-Based Factor Models for Asset Returns
Authors:
Angela Gu,
Patrick Zeng
Abstract:
Factor analysis is a statistical technique employed to evaluate how observed variables correlate through common factors and unique variables. While it is often used to analyze price movement in the unstable stock market, it does not always yield easily interpretable results. In this study, we develop improved factor models by explicitly incorporating sector information on our studied stocks. We add eleven sectors of stocks as defined by the IBES, represented by respective sector-specific factors, to non-specific market factors to revise the factor model. We then develop an expectation maximization (EM) algorithm to compute our revised model with 15 years' worth of S&P 500 stocks' daily close prices. Our results in most sectors show that nearly all of these factor components have the same sign, consistent with the intuitive idea that stocks in the same sector tend to rise and fall in coordination over time. Results obtained by the classic factor model, in contrast, had a homogeneous blend of positive and negative components. We conclude that results produced by our sector-based factor model are more interpretable than those produced by the classic non-sector-based model for at least some stock sectors.
Submitted 11 August, 2014;
originally announced August 2014.