-
Learning single-index models via harmonic decomposition
Authors:
Nirmit Joshi,
Hugo Koubbi,
Theodor Misiakiewicz,
Nathan Srebro
Abstract:
We study the problem of learning single-index models, where the label $y \in \mathbb{R}$ depends on the input $\boldsymbol{x} \in \mathbb{R}^d$ only through an unknown one-dimensional projection $\langle \boldsymbol{w}_*,\boldsymbol{x}\rangle$. Prior work has shown that under Gaussian inputs, the statistical and computational complexity of recovering $\boldsymbol{w}_*$ is governed by the Hermite expansion of the link function. In this paper, we propose a new perspective: we argue that "spherical harmonics" -- rather than "Hermite polynomials" -- provide the natural basis for this problem, as they capture its intrinsic "rotational symmetry". Building on this insight, we characterize the complexity of learning single-index models under arbitrary spherically symmetric input distributions. We introduce two families of estimators -- based on tensor unfolding and online SGD -- that respectively achieve either optimal sample complexity or optimal runtime, and argue that estimators achieving both may not exist in general. When specialized to Gaussian inputs, our theory not only recovers and clarifies existing results but also reveals new phenomena that had previously been overlooked.
Submitted 11 June, 2025;
originally announced June 2025.
-
Enhancing Neural Autoregressive Distribution Estimators for Image Reconstruction
Authors:
Ambrose Emmett-Iwaniw,
Nathan Kirk
Abstract:
Autoregressive models are often employed to learn distributions of image data by decomposing the $D$-dimensional density function into a product of one-dimensional conditional distributions. Each conditional depends on preceding variables (pixels, in the case of image data), making the order in which variables are processed fundamental to the model performance. In this paper, we study the problem of observing a small subset of image pixels (referred to as a pixel patch) to predict the unobserved parts of the image. As our prediction mechanism, we propose a generalized and computationally efficient version of the convolutional neural autoregressive distribution estimator (ConvNADE) model adapted for real-valued and color images. Moreover, we investigate the quality of image reconstruction when observing both random pixel patches and low-discrepancy pixel patches inspired by quasi-Monte Carlo theory. Experiments on benchmark datasets demonstrate that choosing the pixels akin to a low-discrepancy sequence reduces test loss and produces more realistic reconstructed images.
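The idea of observing pixels "akin to a low-discrepancy sequence" can be illustrated with a small sketch. The abstract does not say which sequence the authors use; the Halton sequence below (bases 2 and 3) is an assumption chosen only because it is a standard quasi-Monte Carlo construction, and the names `radical_inverse` and `halton_pixels` are hypothetical.

```python
def radical_inverse(index, base):
    """Van der Corput radical inverse of a positive integer in the given base."""
    result, fraction = 0.0, 1.0 / base
    while index > 0:
        result += (index % base) * fraction
        index //= base
        fraction /= base
    return result


def halton_pixels(n_pixels, height, width):
    """Pick n_pixels distinct (row, col) positions from a 2D Halton sequence.

    Illustrative only: the paper's actual pixel-selection scheme may differ.
    """
    if n_pixels > height * width:
        raise ValueError("cannot select more pixels than the image contains")
    pixels, seen, i = [], set(), 1
    while len(pixels) < n_pixels:
        # Scale the unit-square point to integer pixel coordinates.
        pos = (int(radical_inverse(i, 2) * height), int(radical_inverse(i, 3) * width))
        if pos not in seen:  # skip collisions caused by rounding to the grid
            seen.add(pos)
            pixels.append(pos)
        i += 1
    return pixels


observed = halton_pixels(16, 28, 28)  # 16 observed pixels on a 28x28 image
```

Compared with uniformly random patches, points generated this way tend to spread more evenly over the image, which is the property the abstract credits for the lower test loss.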
Submitted 3 June, 2025;
originally announced June 2025.
-
Simulation-Based Inference for Adaptive Experiments
Authors:
Brian M Cho,
Aurélien Bibaut,
Nathan Kallus
Abstract:
Multi-arm bandit experimental designs are increasingly being adopted over standard randomized trials due to their potential to improve outcomes for study participants, enable faster identification of the best-performing options, and/or enhance the precision of estimating key parameters. Current approaches for inference after adaptive sampling rely either on asymptotic normality under restricted experiment designs or on underpowered martingale concentration inequalities that lead to weak power in practice. To bypass these limitations, we propose a simulation-based approach for conducting hypothesis tests and constructing confidence intervals for arm-specific means and their differences. Our simulation-based approach uses positively biased nuisances to generate additional trajectories of the experiment, which we call "simulation with optimism". Using these simulations, we characterize the distribution of the potentially non-normal sample mean test statistic to conduct inference. We provide guarantees for (i) asymptotic type I error control, (ii) convergence of our confidence intervals, and (iii) asymptotic strong consistency of our estimator over a wide variety of common bandit designs. Our empirical results show that our approach achieves the desired coverage while reducing confidence interval widths by up to 50%, with drastic improvements for arms not targeted by the design.
Submitted 3 June, 2025;
originally announced June 2025.
-
Temperature is All You Need for Generalization in Langevin Dynamics and other Markov Processes
Authors:
Itamar Harel,
Yonathan Wolanowsky,
Gal Vardi,
Nathan Srebro,
Daniel Soudry
Abstract:
We analyze the generalization gap (gap between the training and test errors) when training a potentially over-parametrized model using a Markovian stochastic training algorithm, initialized from some distribution $θ_0 \sim p_0$. We focus on Langevin dynamics with a positive temperature $β^{-1}$, i.e. gradient descent on a training loss $L$ with infinitesimal step size, perturbed with $β^{-1}$-variance Gaussian noise, and lightly regularized or bounded. There, we bound the generalization gap, at any time during training, by $\sqrt{(β\mathbb{E} L (θ_0) + \log(1/δ))/N}$ with probability $1-δ$ over the dataset, where $N$ is the sample size, and $\mathbb{E} L (θ_0) =O(1)$ with standard initialization scaling. In contrast to previous guarantees, we have no dependence on training time, no reliance on mixing, and no dependence on dimensionality, gradient norms, or any other properties of the loss or model. This guarantee follows from a general analysis of any Markov process-based training that has a Gibbs-style stationary distribution. The proof is surprisingly simple, once we observe that the marginal distribution divergence from initialization remains bounded, as implied by a generalized second law of thermodynamics.
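To get a feel for the scale of this bound, the expression $\sqrt{(β\mathbb{E} L (θ_0) + \log(1/δ))/N}$ can be evaluated directly. The sketch below takes the formula exactly as written in the abstract; any constants or lower-order terms in the paper's full statement are not included, and the function name and the example values are hypothetical.

```python
import math


def generalization_gap_bound(beta, expected_initial_loss, delta, n_samples):
    """Evaluate sqrt((beta * E[L(theta_0)] + log(1/delta)) / N) as stated in the abstract.

    beta is the inverse temperature, delta the failure probability,
    n_samples the dataset size N. Constants from the paper are omitted.
    """
    if not (0.0 < delta < 1.0) or n_samples <= 0:
        raise ValueError("need 0 < delta < 1 and n_samples > 0")
    return math.sqrt((beta * expected_initial_loss + math.log(1.0 / delta)) / n_samples)


# Illustrative values only: inverse temperature 100, E[L(theta_0)] = 1,
# 95% confidence, one million samples.
example = generalization_gap_bound(100.0, 1.0, 0.05, 1_000_000)
```

The expression depends only on the inverse temperature, the expected initial loss, the confidence level, and the sample size, which matches the abstract's claim that training time and dimensionality do not enter.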
Submitted 25 May, 2025;
originally announced May 2025.
-
Efficient Adaptive Experimentation with Non-Compliance
Authors:
Miruna Oprescu,
Brian M Cho,
Nathan Kallus
Abstract:
We study the problem of estimating the average treatment effect (ATE) in adaptive experiments where treatment can only be encouraged--rather than directly assigned--via a binary instrumental variable. Building on semiparametric efficiency theory, we derive the efficiency bound for ATE estimation under arbitrary, history-dependent instrument-assignment policies, and show it is minimized by a variance-aware allocation rule that balances outcome noise and compliance variability. Leveraging this insight, we introduce AMRIV--an Adaptive, Multiply-Robust estimator for Instrumental-Variable settings with variance-optimal assignment. AMRIV pairs (i) an online policy that adaptively approximates the optimal allocation with (ii) a sequential, influence-function-based estimator that attains the semiparametric efficiency bound while retaining multiply-robust consistency. We establish asymptotic normality, explicit convergence rates, and anytime-valid asymptotic confidence sequences that enable sequential inference. Finally, we demonstrate the practical effectiveness of our approach through empirical studies, showing that adaptive instrument assignment, when combined with the AMRIV estimator, yields improved efficiency and robustness compared to existing baselines.
Submitted 23 May, 2025;
originally announced May 2025.
-
Contextual Phenotyping of Pediatric Sepsis Cohort Using Large Language Models
Authors:
Aditya Nagori,
Ayush Gautam,
Matthew O. Wiens,
Vuong Nguyen,
Nathan Kenya Mugisha,
Jerome Kabakyenga,
Niranjan Kissoon,
John Mark Ansermino,
Rishikesan Kamaleswaran
Abstract:
Clustering patient subgroups is essential for personalized care and efficient resource use. Traditional clustering methods struggle with high-dimensional, heterogeneous healthcare data and lack contextual understanding. This study evaluates Large Language Model (LLM) based clustering against classical methods using a pediatric sepsis dataset from a low-income country (LIC), containing 2,686 records with 28 numerical and 119 categorical variables. Patient records were serialized into text with and without a clustering objective. Embeddings were generated using quantized LLAMA 3.1 8B, DeepSeek-R1-Distill-Llama-8B with low-rank adaptation (LoRA), and Stella-En-400M-V5 models. K-means clustering was applied to these embeddings. Classical comparisons included K-Medoids clustering on UMAP and FAMD-reduced mixed data. Silhouette scores and statistical tests evaluated cluster quality and distinctiveness. Stella-En-400M-V5 achieved the highest Silhouette Score (0.86). LLAMA 3.1 8B with the clustering objective performed better with a higher number of clusters, identifying subgroups with distinct nutritional, clinical, and socioeconomic profiles. LLM-based methods outperformed classical techniques by capturing richer context and prioritizing key features. These results highlight the potential of LLMs for contextual phenotyping and informed decision-making in resource-limited settings.
Submitted 14 May, 2025;
originally announced May 2025.
-
LatticeVision: Image to Image Networks for Modeling Non-Stationary Spatial Data
Authors:
Antony Sikorski,
Michael Ivanitskiy,
Nathan Lenssen,
Douglas Nychka,
Daniel McKenzie
Abstract:
In many scientific and industrial applications, we are given a handful of instances (a 'small ensemble') of a spatially distributed quantity (a 'field') but would like to acquire many more. For example, a large ensemble of global temperature sensitivity fields from a climate model can help farmers, insurers, and governments plan appropriately. When acquiring more data is prohibitively expensive -- as is the case with climate models -- statistical emulation offers an efficient alternative for simulating synthetic yet realistic fields. However, parameter inference using maximum likelihood estimation (MLE) is computationally prohibitive, especially for large, non-stationary fields. Thus, many recent works train neural networks to estimate parameters given spatial fields as input, sidestepping MLE completely. In this work we focus on a popular class of parametric, spatially autoregressive (SAR) models. We make a simple yet impactful observation: because the SAR parameters can be arranged on a regular grid, both inputs (spatial fields) and outputs (model parameters) can be viewed as images. Using this insight, we demonstrate that image-to-image (I2I) networks enable faster and more accurate parameter estimation for a class of non-stationary SAR models with unprecedented complexity.
Submitted 14 May, 2025;
originally announced May 2025.
-
Lower Bounds on the MMSE of Adversarially Inferring Sensitive Features
Authors:
Monica Welfert,
Nathan Stromberg,
Mario Diaz,
Lalitha Sankar
Abstract:
We propose an adversarial evaluation framework for sensitive feature inference based on minimum mean-squared error (MMSE) estimation with a finite sample size and linear predictive models. Our approach establishes theoretical lower bounds on the true MMSE of inferring sensitive features from noisy observations of other correlated features. These bounds are expressed in terms of the empirical MMSE under a restricted hypothesis class and a non-negative error term. The error term captures both the estimation error due to a finite number of samples and the approximation error from using a restricted hypothesis class. For linear predictive models, we derive closed-form bounds, which are order optimal in terms of the noise variance, on the approximation error for several classes of relationships between the sensitive and non-sensitive features, including linear mappings, binary symmetric channels, and class-conditional multi-variate Gaussian distributions. We also present a new lower bound that relies on the MSE computed on a hold-out validation dataset of the MMSE estimator learned on finite samples and a restricted hypothesis class. Through empirical evaluation, we demonstrate that our framework serves as an effective tool for MMSE-based adversarial evaluation of sensitive feature inference that balances theoretical guarantees with practical efficiency.
Submitted 13 May, 2025;
originally announced May 2025.
-
Nonparametric Instrumental Variable Inference with Many Weak Instruments
Authors:
Lars van der Laan,
Nathan Kallus,
Aurélien Bibaut
Abstract:
We study inference on linear functionals in the nonparametric instrumental variable (NPIV) problem with a discretely-valued instrument under a many-weak-instruments asymptotic regime, where the number of instrument values grows with the sample size. A key motivating example is estimating long-term causal effects in a new experiment with only short-term outcomes, using past experiments to instrument for the effect of short- on long-term outcomes. Here, the assignment to a past experiment serves as the instrument: we have many past experiments but only a limited number of units in each. Since the structural function is nonparametric but constrained by only finitely many moment restrictions, point identification typically fails. To address this, we consider linear functionals of the minimum-norm solution to the moment restrictions, which is always well-defined. As the number of instrument levels grows, these functionals define an approximating sequence to a target functional, replacing point identification with a weaker asymptotic notion suited to discrete instruments. Extending the Jackknife Instrumental Variable Estimator (JIVE) beyond the classical parametric setting, we propose npJIVE, a nonparametric estimator for solutions to linear inverse problems with many weak instruments. We construct automatic debiased machine learning estimators for linear functionals of both the structural function and its minimum-norm projection, and establish their efficiency in the many-weak-instruments regime.
Submitted 12 May, 2025;
originally announced May 2025.
-
Kernel Dynamic Mode Decomposition For Sparse Reconstruction of Closable Koopman Operators
Authors:
Nishant Panda,
Himanshu Singh,
J. Nathan Kutz
Abstract:
Spatial-temporal reconstruction of dynamical systems is a crucial problem with diverse applications ranging from climate modeling to numerous chaotic and physical processes. These reconstructions are based on the harmonious relationship between the Koopman operators and the choice of dictionary, determined implicitly by a kernel function. This leads to the approximation of the Koopman operators in a reproducing kernel Hilbert space (RKHS) associated with that kernel function. Data-driven analysis of Koopman operators demands that Koopman operators be closable over the underlying RKHS, which remains an unsettled, unexplored, and critical operator-theoretic challenge. We aim to address this challenge by investigating the embedding of the Laplacian kernel in the measure-theoretic sense, giving rise to a rich enough RKHS to settle the closability of the Koopman operators. We leverage Kernel Extended Dynamic Mode Decomposition with the Laplacian kernel to reconstruct the dominant spatial-temporal modes of diverse dynamical systems. After empirical demonstration, we make these results concrete by providing theoretical justification that leverages the closability of the Koopman operators on the RKHS generated by the Laplacian kernel, via the Koopman mode decomposition and the Koopman spectral measure. These results are explored from the standpoints of both operator theory and data-driven science, making the Laplacian kernel a robust choice for spatial-temporal reconstruction.
Submitted 10 May, 2025;
originally announced May 2025.
-
Mixed-Integer Optimization for Responsible Machine Learning
Authors:
Nathan Justin,
Qingshi Sun,
Andrés Gómez,
Phebe Vayanos
Abstract:
In the last few decades, Machine Learning (ML) has achieved significant success across domains ranging from healthcare, sustainability, and the social sciences, to criminal justice and finance. But its deployment in increasingly sophisticated, critical, and sensitive areas affecting individuals, the groups they belong to, and society as a whole raises critical concerns around fairness, transparency, robustness, and privacy, among others. As the complexity and scale of ML systems and of the settings in which they are deployed grow, so does the need for responsible ML methods that address these challenges while providing guaranteed performance in deployment.
Mixed-integer optimization (MIO) offers a powerful framework for embedding responsible ML considerations directly into the learning process while maintaining performance. For example, it enables learning of inherently transparent models that can conveniently incorporate fairness or other domain specific constraints. This tutorial paper provides an accessible and comprehensive introduction to this topic discussing both theoretical and practical aspects. It outlines some of the core principles of responsible ML, their importance in applications, and the practical utility of MIO for building ML models that align with these principles. Through examples and mathematical formulations, it illustrates practical strategies and available tools for efficiently solving MIO problems for responsible ML. It concludes with a discussion on current limitations and open research questions, providing suggestions for future work.
Submitted 9 May, 2025;
originally announced May 2025.
-
Multilevel Sampling in Algebraic Statistics
Authors:
Nathan Kirk,
Ivan Gvozdanović,
Sonja Petrović
Abstract:
This paper proposes a multilevel sampling algorithm for fiber sampling problems in algebraic statistics, inspired by Henry Wynn's suggestion to adapt multilevel Monte Carlo (MLMC) ideas to discrete models. Focusing on log-linear models, we sample from high-dimensional lattice fibers defined by algebraic constraints. Building on Markov basis methods and results from Diaconis and Sturmfels, our algorithm uses variable step sizes to accelerate exploration and reduce the need for long burn-in. We introduce a novel Fiber Coverage Score (FCS) based on Voronoi partitioning to assess sample quality, and highlight the utility of the Maximum Mean Discrepancy (MMD) quality metric. Simulations on benchmark fibers show that multilevel sampling outperforms naive MCMC approaches. Our results demonstrate that multilevel methods, when properly applied, provide practical benefits for discrete sampling in algebraic statistics.
Submitted 6 May, 2025;
originally announced May 2025.
-
Multi-site modelling and reconstruction of past extreme skew surges along the French Atlantic coast
Authors:
Nathan Huet,
Philippe Naveau,
Anne Sabourin
Abstract:
Appropriate modelling of extreme skew surges is crucial, particularly for coastal risk management. Our study focuses on modelling extreme skew surges along the French Atlantic coast, with a particular emphasis on investigating the extremal dependence structure between stations. We employ the peak-over-threshold framework, where a multivariate extreme event is defined whenever at least one location records a large value, though not necessarily all stations simultaneously. A novel method for determining an appropriate level (threshold) above which observations can be classified as extreme is proposed. Two complementary approaches are explored. First, the multivariate generalized Pareto distribution is employed to model extremes, leveraging its properties to derive a generative model that predicts extreme skew surges at one station based on observed extremes at nearby stations. Second, a novel extreme regression framework is assessed for point predictions. This specific regression framework enables accurate point predictions using only the "angle" of input variables, i.e. input variables divided by their norms. The ultimate objective is to reconstruct historical skew surge time series at stations with limited data. This is achieved by integrating extreme skew surge data from stations with longer records, such as Brest and Saint-Nazaire, which provide over 150 years of observations.
Submitted 1 May, 2025;
originally announced May 2025.
-
From Continual Learning to SGD and Back: Better Rates for Continual Linear Models
Authors:
Itay Evron,
Ran Levinstein,
Matan Schliserman,
Uri Sherman,
Tomer Koren,
Daniel Soudry,
Nathan Srebro
Abstract:
We theoretically study the common continual learning setup where an overparameterized model is sequentially fitted to a set of jointly realizable tasks. We analyze the forgetting, i.e., loss on previously seen tasks, after $k$ iterations. For continual linear models, we prove that fitting a task is equivalent to a single stochastic gradient descent (SGD) step on a modified objective. We develop novel last-iterate SGD upper bounds in the realizable least squares setup, which we then leverage to derive new results for continual learning. Focusing on random orderings over $T$ tasks, we establish universal forgetting rates, whereas existing rates depend on the problem dimensionality or complexity. Specifically, in continual regression with replacement, we improve the best existing rate from $O((d-r)/k)$ to $O(\min(k^{-1/4}, \sqrt{d-r}/k, \sqrt{Tr}/k))$, where $d$ is the dimensionality and $r$ the average task rank. Furthermore, we establish the first rate for random task orderings without replacement. The obtained rate of $O(\min(T^{-1/4}, (d-r)/T))$ proves for the first time that randomization alone, with no task repetition, can prevent catastrophic forgetting in sufficiently long task sequences. Finally, we prove a matching $O(k^{-1/4})$ forgetting rate for continual linear classification on separable data. Our universal rates apply for broader projection methods, such as block Kaczmarz and POCS, illuminating their loss convergence under i.i.d. and one-pass orderings.
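The link between sequential task fitting and projection methods such as Kaczmarz can be sketched in a few lines. The example below is a simplified illustration, not the paper's construction: it assumes every task is a single linear equation (rank 1), fits each task with the minimum-norm update, and tracks the average squared loss over all tasks as a stand-in for forgetting. In this rank-1 case the update is a gradient step on the squared loss with step size $1/\|x\|^2$. All names and problem sizes are hypothetical.

```python
import random


def dot(a, b):
    return sum(ai * bi for ai, bi in zip(a, b))


def fit_task(w, x, y):
    """Fit the rank-1 task <x, w> = y with the minimum-norm change to w.

    This is one Kaczmarz projection, i.e. a gradient step on
    0.5 * (y - <x, w>)**2 with step size 1 / ||x||^2.
    """
    scale = (y - dot(w, x)) / dot(x, x)
    return [wi + scale * xi for wi, xi in zip(w, x)]


def average_loss(w, tasks):
    """Mean squared error over all tasks (a proxy for forgetting)."""
    return sum((dot(w, x) - y) ** 2 for x, y in tasks) / len(tasks)


rng = random.Random(0)
d, T, k = 20, 5, 200  # dimension, number of tasks, number of iterations
w_star = [rng.gauss(0, 1) for _ in range(d)]
tasks = []
for _ in range(T):
    x = [rng.gauss(0, 1) for _ in range(d)]
    tasks.append((x, dot(w_star, x)))  # jointly realizable: all labels come from w_star

w = [0.0] * d
initial_loss = average_loss(w, tasks)
for _ in range(k):
    x, y = rng.choice(tasks)  # random task ordering with replacement
    w = fit_task(w, x, y)
final_loss = average_loss(w, tasks)
```

Each call to `fit_task` solves the current task exactly while moving `w` as little as possible, so the distance to the joint solution set never increases; the abstract's rates quantify how fast the loss on earlier tasks shrinks under such random orderings.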
Submitted 27 May, 2025; v1 submitted 6 April, 2025;
originally announced April 2025.
-
Stochastic ecohydrological perspective on semi-distributed rainfall-runoff dynamics
Authors:
Mark S. Bartlett,
Elizabeth Cultra,
Nathan Geldner,
Amilcare Porporato
Abstract:
Quantifying watershed process variability consistently with climate change and ecohydrological dynamics remains a central challenge in hydrology. Stochastic ecohydrology characterizes hydrologic variability through probability distributions that link climate, hydrology, and ecology. However, these approaches are often limited to small spatial scales (e.g., point or plot level) or focus on specific fluxes (e.g., streamflow), without accounting for the entire water balance at the basin scale. While semi-distributed models account for spatial heterogeneity and upscaled hydrologic fluxes, they lack the analytical simplicity of stochastic ecohydrology or the SCS-CN method and, perhaps more importantly, do not integrate the effects of past random variability in hydroclimatic conditions. This hinders an efficient characterization of hydrological statistics at the watershed scale. To overcome these limitations, we merge stochastic ecohydrology, the spatial upscaling of semi-distributed modeling, and the SCS-CN rainfall-runoff partitioning. The resulting unified model analytically characterizes watershed ecohydrological and hydrological statistics using probability density functions (PDFs) that are functions of climate and watershed attributes -- something unattainable with the Monte Carlo methods of traditional stochastic hydrology. Calibrated across 81 watersheds in Florida and southern Louisiana, the model PDFs precisely capture the long-term average water balance and runoff variance, as well as the runoff quantiles with a median normalized Nash-Sutcliffe (NNSE) efficiency of 0.95. These results also advance the SCS-CN method by providing an analytical PDF for the Curve Number (CN), explicitly linked to climate variables, baseflow, and the long-term water balance partitioning described by the Budyko curve.
Submitted 24 March, 2025;
originally announced March 2025.
-
SNPL: Simultaneous Policy Learning and Evaluation for Safe Multi-Objective Policy Improvement
Authors:
Brian Cho,
Ana-Roxana Pop,
Ariel Evnine,
Nathan Kallus
Abstract:
To design effective digital interventions, experimenters face the challenge of learning decision policies that balance multiple objectives using offline data. Often, they aim to develop policies that maximize goal outcomes, while ensuring there are no undesirable changes in guardrail outcomes. To provide credible recommendations, experimenters must not only identify policies that satisfy the desired changes in goal and guardrail outcomes, but also offer probabilistic guarantees about the changes these policies induce. In practice, however, policy classes are often large, and digital experiments tend to produce datasets with small effect sizes relative to noise. In this setting, standard approaches such as data splitting or multiple testing often result in unstable policy selection and/or insufficient statistical power. In this paper, we provide safe noisy policy learning (SNPL), a novel approach that leverages the concept of algorithmic stability to address these challenges. Our method enables policy learning while simultaneously providing high-confidence guarantees using the entire dataset, avoiding the need for data-splitting. We present finite-sample and asymptotic versions of our algorithm that ensure the recommended policy satisfies high-probability guarantees for avoiding guardrail regressions and/or achieving goal outcome improvements. We test both variants of our approach empirically on a real-world application of personalizing SMS delivery. Our results on real-world data suggest that our approach offers dramatic improvements in settings with large policy classes and low signal-to-noise across both finite-sample and asymptotic safety guarantees, offering up to 300% improvements in detection rates and 150% improvements in policy gains at significantly smaller sample sizes.
Submitted 21 March, 2025; v1 submitted 16 March, 2025;
originally announced March 2025.
-
Mixed-feature Logistic Regression Robust to Distribution Shifts
Authors:
Qingshi Sun,
Nathan Justin,
Andres Gomez,
Phebe Vayanos
Abstract:
Logistic regression models are widely used in the social and behavioral sciences and in high-stakes domains, due to their simplicity and interpretability properties. At the same time, such domains are permeated by distribution shifts, where the distribution generating the data changes between training and deployment. In this paper, we study a distributionally robust logistic regression problem that seeks the model that will perform best against adversarial realizations of the data distribution drawn from a suitably constructed Wasserstein ambiguity set. Our model and solution approach differ from prior work in that we can capture settings where the likelihood of distribution shifts can vary across features, significantly broadening the applicability of our model relative to the state-of-the-art. We propose a graph-based solution approach that can be integrated into off-the-shelf optimization solvers. We evaluate the performance of our model and algorithms on numerous publicly available datasets. Our solution achieves a 408x speed-up relative to the state-of-the-art. Additionally, compared to the state-of-the-art, our model reduces average calibration error by up to 36.19% and worst-case calibration error by up to 41.70%, while increasing the average area under the ROC curve (AUC) by up to 18.02% and worst-case AUC by up to 48.37%.
Submitted 15 March, 2025;
originally announced March 2025.
-
A Theory of Learning with Autoregressive Chain of Thought
Authors:
Nirmit Joshi,
Gal Vardi,
Adam Block,
Surbhi Goel,
Zhiyuan Li,
Theodor Misiakiewicz,
Nathan Srebro
Abstract:
For a given base class of sequence-to-next-token generators, we consider learning prompt-to-answer mappings obtained by iterating a fixed, time-invariant generator for multiple steps, thus generating a chain-of-thought, and then taking the final token as the answer. We formalize the learning problems both when the chain-of-thought is observed and when training only on prompt-answer pairs, with the chain-of-thought latent. We analyze the sample and computational complexity both in terms of general properties of the base class (e.g. its VC dimension) and for specific base classes such as linear thresholds. We present a simple base class that allows for universal representability and computationally tractable chain-of-thought learning. Central to our development is that time invariance allows for sample complexity that is independent of the length of the chain-of-thought. Attention arises naturally in our construction.
Submitted 10 March, 2025;
originally announced March 2025.
-
Interactive Visualization Framework for Forensic Bullet Comparisons
Authors:
Nathan Rethwisch,
Heike Hofmann
Abstract:
The current method for forensic analysis of bullet comparison relies on manual examination by forensic examiners to determine if bullets were discharged from the same firearm. This process is highly subjective, prompting the development of algorithmic methods to provide objective statistical support for comparisons. However, a gap exists between the technical understanding of these algorithms and the typical background of many forensic examiners. We present a visualization tool designed to bridge this gap, allowing for the presentation of statistical information in a more familiar format to forensic professionals. The forensic bullet comparison visualizer (FBCV) features a variety of plots that will enable the user to examine every step of the algorithmic comparison process. We demonstrate the utility of the FBCV by applying it to data from the Houston Science Lab, where it helped identify an error in the comparison process caused by mislabeling. This tool can be used for future investigations, such as examining how distance between shots affects scores. The FBCV offers a user-friendly way to convey complex statistical information to forensic examiners, facilitating their understanding and utilization of algorithmic comparison methods.
Submitted 7 March, 2025;
originally announced March 2025.
-
Weak-to-Strong Generalization Even in Random Feature Networks, Provably
Authors:
Marko Medvedev,
Kaifeng Lyu,
Dingli Yu,
Sanjeev Arora,
Zhiyuan Li,
Nathan Srebro
Abstract:
Weak-to-Strong Generalization (Burns et al., 2024) is the phenomenon whereby a strong student, say GPT-4, learns a task from a weak teacher, say GPT-2, and ends up significantly outperforming the teacher. We show that this phenomenon does not require a strong learner like GPT-4. We consider student and teacher that are random feature models, described by two-layer networks with a random and fixed bottom layer and a trained top layer. A "weak" teacher, with a small number of units (i.e. random features), is trained on the population, and a "strong" student, with a much larger number of units (i.e. random features), is trained only on labels generated by the weak teacher. We demonstrate, prove, and understand how the student can outperform the teacher, even though trained only on data labeled by the teacher. We also explain how such weak-to-strong generalization is enabled by early stopping. Importantly, we also show the quantitative limits of weak-to-strong generalization in this model.
Submitted 4 March, 2025;
originally announced March 2025.
-
Quantifying Overfitting along the Regularization Path for Two-Part-Code MDL in Supervised Classification
Authors:
Xiaohan Zhu,
Nathan Srebro
Abstract:
We provide a complete characterization of the entire regularization curve of a modified two-part-code Minimum Description Length (MDL) learning rule for binary classification, based on an arbitrary prior or description language. Grunwald and Langford [2004] previously established the lack of asymptotic consistency, from an agnostic PAC (frequentist worst case) perspective, of the MDL rule with a penalty parameter of $λ=1$, suggesting that it under-regularizes. Driven by interest in understanding how benign or catastrophic under-regularization and overfitting might be, we obtain a precise quantitative description of the worst case limiting error as a function of the regularization parameter $λ$ and noise level (or approximation error), significantly tightening the analysis of Grunwald and Langford for $λ=1$ and extending it to all other choices of $λ$.
Submitted 10 March, 2025; v1 submitted 3 March, 2025;
originally announced March 2025.
-
Tight Bounds on the Binomial CDF, and the Minimum of i.i.d Binomials, in terms of KL-Divergence
Authors:
Xiaohan Zhu,
Mesrob I. Ohannessian,
Nathan Srebro
Abstract:
We provide finite sample upper and lower bounds on the Binomial tail probability which are a direct application of Sanov's theorem. We then use these to obtain high probability upper and lower bounds on the minimum of i.i.d. Binomial random variables. Both bounds are finite sample, asymptotically tight, and expressed in terms of the KL-divergence.
Submitted 25 February, 2025;
originally announced February 2025.
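As a rough illustration of what a KL-divergence bound on a Binomial tail looks like, the sketch below compares the exact lower tail with the classical Chernoff bound $P(X \le k) \le \exp(-n\,\mathrm{KL}(k/n \,\|\, p))$ for $k/n \le p$. This is the textbook bound, not the tighter two-sided bounds of the paper; the function names and the choice of $n$, $p$, and $k$ are assumptions made for the example.

```python
import math

def kl_bernoulli(q, p):
    """KL(Ber(q) || Ber(p)) in nats, using the convention 0 * log(0) = 0."""
    total = 0.0
    if q > 0:
        total += q * math.log(q / p)
    if q < 1:
        total += (1 - q) * math.log((1 - q) / (1 - p))
    return total

def binomial_cdf(k, n, p):
    """Exact P(X <= k) for X ~ Binomial(n, p)."""
    return sum(math.comb(n, i) * p**i * (1 - p) ** (n - i) for i in range(k + 1))

def chernoff_kl_upper(k, n, p):
    """Classical upper bound exp(-n * KL(k/n || p)), valid when k/n <= p."""
    return math.exp(-n * kl_bernoulli(k / n, p))

# Hypothetical parameters chosen only for illustration.
n, p = 100, 0.5
for k in (20, 30, 40):
    exact = binomial_cdf(k, n, p)
    bound = chernoff_kl_upper(k, n, p)
    print(f"k={k}: exact tail={exact:.3e}, KL upper bound={bound:.3e}")
```

The gap between the two columns is the polynomial factor that finite-sample lower bounds of this type aim to pin down.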
-
Forecasting time series with constraints
Authors:
Nathan Doumèche,
Francis Bach,
Éloi Bedek,
Gérard Biau,
Claire Boyer,
Yannig Goude
Abstract:
Time series forecasting presents unique challenges that limit the effectiveness of traditional machine learning algorithms. To address these limitations, various approaches have incorporated linear constraints into learning algorithms, such as generalized additive models and hierarchical forecasting. In this paper, we propose a unified framework for integrating and combining linear constraints in time series forecasting. Within this framework, we show that the exact minimizer of the constrained empirical risk can be computed efficiently using linear algebra alone. This approach allows for highly scalable implementations optimized for GPUs. We validate the proposed methodology through extensive benchmarking on real-world tasks, including electricity demand forecasting and tourism forecasting, achieving state-of-the-art performance.
Submitted 14 February, 2025;
originally announced February 2025.
-
Evaluating Decision Rules Across Many Weak Experiments
Authors:
Winston Chou,
Colin Gray,
Nathan Kallus,
Aurélien Bibaut,
Simon Ejdemyr
Abstract:
Technology firms conduct randomized controlled experiments ("A/B tests") to learn which actions to take to improve business outcomes. In firms with mature experimentation platforms, experimentation programs can consist of many thousands of tests. To effectively scale experimentation, firms rely on decision rules: standard operating procedures for mapping the results of an experiment to a choice of treatment arm to launch to the general user population. Despite the critical role of decision rules in translating experimentation into business decisions, rigorous guidance on how to evaluate and choose decision rules is scarce. This paper proposes to evaluate decision rules based on their cumulative returns to business north star metrics. Although intuitive and easy to explain to decision-makers, this quantity can be difficult to estimate, especially when experiments have weak signal-to-noise ratios. We develop a cross-validation estimator that is much less biased than the naive plug-in estimator under conditions realistic to digital experimentation. We demonstrate the efficacy of our approach via a case study of 123 historical A/B tests at Netflix, where we used it to show that a new decision rule would have increased cumulative returns to the north star metric by an estimated $33\%$, directly leading to the adoption of the new rule.
Submitted 29 May, 2025; v1 submitted 12 February, 2025;
originally announced February 2025.
-
Uncertainty-Aware Adaptation of Large Language Models for Protein-Protein Interaction Analysis
Authors:
Sanket Jantre,
Tianle Wang,
Gilchan Park,
Kriti Chopra,
Nicholas Jeon,
Xiaoning Qian,
Nathan M. Urban,
Byung-Jun Yoon
Abstract:
Identification of protein-protein interactions (PPIs) helps derive cellular mechanistic understanding, particularly in the context of complex conditions such as neurodegenerative disorders, metabolic syndromes, and cancer. Large Language Models (LLMs) have demonstrated remarkable potential in predicting protein structures and interactions via automated mining of vast biomedical literature; yet their inherent uncertainty remains a key challenge for deriving reproducible findings, critical for biomedical applications. In this study, we present an uncertainty-aware adaptation of LLMs for PPI analysis, leveraging fine-tuned LLaMA-3 and BioMedGPT models. To enhance prediction reliability, we integrate LoRA ensembles and Bayesian LoRA models for uncertainty quantification (UQ), ensuring confidence-calibrated insights into protein behavior. Our approach achieves competitive performance in PPI identification across diverse disease contexts while addressing model uncertainty, thereby enhancing trustworthiness and reproducibility in computational biology. These findings underscore the potential of uncertainty-aware LLM adaptation for advancing precision medicine and biomedical research.
Submitted 10 February, 2025;
originally announced February 2025.
-
GST-UNet: Spatiotemporal Causal Inference with Time-Varying Confounders
Authors:
Miruna Oprescu,
David K. Park,
Xihaier Luo,
Shinjae Yoo,
Nathan Kallus
Abstract:
Estimating causal effects from spatiotemporal data is a key challenge in fields such as public health, social policy, and environmental science, where controlled experiments are often infeasible. However, existing causal inference methods relying on observational data face significant limitations: they depend on strong structural assumptions to address spatiotemporal challenges -- such as interference, spatial confounding, and temporal carryover effects -- or fail to account for time-varying confounders. These confounders, influenced by past treatments and outcomes, can themselves shape future treatments and outcomes, creating feedback loops that complicate traditional adjustment strategies. To address these challenges, we introduce the GST-UNet (G-computation Spatio-Temporal UNet), a novel end-to-end neural network framework designed to estimate treatment effects in complex spatial and temporal settings. The GST-UNet leverages regression-based iterative G-computation to explicitly adjust for time-varying confounders, providing valid estimates of potential outcomes and treatment effects. To the best of our knowledge, the GST-UNet is the first neural model to account for complex, non-linear dynamics and time-varying confounders in spatiotemporal interventions. We demonstrate the effectiveness of the GST-UNet through extensive simulation studies and showcase its practical utility with a real-world analysis of the impact of wildfire smoke on respiratory hospitalizations during the 2018 California Camp Fire. Our results highlight the potential of GST-UNet to advance spatiotemporal causal inference across a wide range of policy-driven and scientific applications.
Submitted 7 February, 2025;
originally announced February 2025.
-
Quasi-Monte Carlo Methods: What, Why, and How?
Authors:
Fred J. Hickernell,
Nathan Kirk,
Aleksei G. Sorokin
Abstract:
Many questions in quantitative finance, uncertainty quantification, and other disciplines are answered by computing the population mean, $μ:= \mathbb{E}(Y)$, where instances of $Y:=f(\boldsymbol{X})$ may be generated by numerical simulation and $\boldsymbol{X}$ has a simple probability distribution. The population mean can be approximated by the sample mean, $\hat{μ}_n := n^{-1} \sum_{i=0}^{n-1} f(\boldsymbol{x}_i)$ for a well chosen sequence of nodes, $\{\boldsymbol{x}_0, \boldsymbol{x}_1, \ldots\}$ and a sufficiently large sample size, $n$. Computing $μ$ is equivalent to computing a $d$-dimensional integral, $\int f(\boldsymbol{x}) \varrho(\boldsymbol{x}) \, \mathrm{d} \boldsymbol{x}$, where $\varrho$ is the probability density for $\boldsymbol{X}$.
Quasi-Monte Carlo methods replace independent and identically distributed sequences of random vector nodes, $\{\boldsymbol{x}_i \}_{i = 0}^{\infty}$, by low discrepancy sequences. This accelerates the convergence of $\hat{μ}_n$ to $μ$ as $n \to \infty$.
This tutorial describes low discrepancy sequences and their quality measures. We demonstrate the performance gains possible with quasi-Monte Carlo methods. Moreover, we describe how to formulate problems to realize the greatest performance gains using quasi-Monte Carlo. We also briefly describe the use of quasi-Monte Carlo methods for problems beyond computing the mean, $μ$.
Submitted 5 February, 2025;
originally announced February 2025.
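A minimal sketch of the sample-mean estimator described in this abstract, comparing IID uniform nodes with a low discrepancy sequence. The Halton sequence in bases 2 and 3, the integrand $f(\boldsymbol{x}) = x_1 x_2$ on the unit square (true mean 0.25), and all function names are assumptions chosen for illustration; the tutorial itself may use other sequences and test functions.

```python
import random

def radical_inverse(i, base):
    """Van der Corput radical inverse of the integer i >= 0 in the given base."""
    inv, denom = 0.0, 1.0
    while i > 0:
        i, digit = divmod(i, base)
        denom *= base
        inv += digit / denom
    return inv

def halton(n, bases=(2, 3)):
    """First n points (i = 0, ..., n-1) of the Halton sequence for the given bases."""
    return [tuple(radical_inverse(i, b) for b in bases) for i in range(n)]

def f(x):
    """Illustrative integrand; its mean under the uniform density on [0,1]^2 is 0.25."""
    return x[0] * x[1]

def sample_mean(points):
    """Sample mean n^{-1} * sum_i f(x_i) over the supplied nodes."""
    return sum(f(x) for x in points) / len(points)

n = 1024
true_mean = 0.25
qmc_estimate = sample_mean(halton(n))
rng = random.Random(0)  # fixed seed so the IID run is reproducible
mc_estimate = sample_mean([(rng.random(), rng.random()) for _ in range(n)])
print(f"QMC error: {abs(qmc_estimate - true_mean):.2e}")
print(f"IID error: {abs(mc_estimate - true_mean):.2e}")
```

Only the node sequence changes between the two estimates; the estimator itself is the same sample mean.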
-
Nonparametric Sparse Online Learning of the Koopman Operator
Authors:
Boya Hou,
Sina Sanjari,
Nathan Dahlin,
Alec Koppel,
Subhonmesh Bose
Abstract:
The Koopman operator provides a powerful framework for representing the dynamics of general nonlinear dynamical systems. Data-driven techniques to learn the Koopman operator typically assume that the chosen function space is closed under system dynamics. In this paper, we study the Koopman operator via its action on the reproducing kernel Hilbert space (RKHS), and explore the mis-specified scenario where the dynamics may escape the chosen function space. We relate the Koopman operator to the conditional mean embeddings (CME) operator and then present an operator stochastic approximation algorithm to learn the Koopman operator iteratively with control over the complexity of the representation. We provide both asymptotic and finite-time last-iterate guarantees of the online sparse learning algorithm with trajectory-based sampling with an analysis that is substantially more involved than that for finite-dimensional stochastic approximation. Numerical examples confirm the effectiveness of the proposed algorithm.
Submitted 4 February, 2025; v1 submitted 27 January, 2025;
originally announced January 2025.
-
Automatic Debiased Machine Learning for Smooth Functionals of Nonparametric M-Estimands
Authors:
Lars van der Laan,
Aurelien Bibaut,
Nathan Kallus,
Alex Luedtke
Abstract:
We propose a unified framework for automatic debiased machine learning (autoDML) to perform inference on smooth functionals of infinite-dimensional M-estimands, defined as population risk minimizers over Hilbert spaces. By automating debiased estimation and inference procedures in causal inference and semiparametric statistics, our framework enables practitioners to construct valid estimators for complex parameters without requiring specialized expertise. The framework supports Neyman-orthogonal loss functions with unknown nuisance parameters requiring data-driven estimation, as well as vector-valued M-estimands involving simultaneous loss minimization across multiple Hilbert space models. We formalize the class of parameters efficiently estimable by autoDML as a novel class of nonparametric projection parameters, defined via orthogonal minimum loss objectives. We introduce three autoDML estimators based on one-step estimation, targeted minimum loss-based estimation, and the method of sieves. For data-driven model selection, we derive a novel decomposition of model approximation error for smooth functionals of M-estimands and propose adaptive debiased machine learning estimators that are superefficient and adaptive to the functional form of the M-estimand. Finally, we illustrate the flexibility of our framework by constructing autoDML estimators for the long-term survival under a beta-geometric model.
Submitted 20 January, 2025;
originally announced January 2025.
-
Automatic Double Reinforcement Learning in Semiparametric Markov Decision Processes with Applications to Long-Term Causal Inference
Authors:
Lars van der Laan,
David Hubbard,
Allen Tran,
Nathan Kallus,
Aurélien Bibaut
Abstract:
Estimating long-term causal effects from short-term data is essential for decision-making in healthcare, economics, and industry, where long-term follow-up is often infeasible. Markov Decision Processes (MDPs) offer a principled framework for modeling outcomes as sequences of states, actions, and rewards over time. We introduce a semiparametric extension of Double Reinforcement Learning (DRL) for statistically efficient, model-robust inference on linear functionals of the Q-function, such as policy values, in infinite-horizon, time-homogeneous MDPs. By imposing semiparametric structure on the Q-function, our method relaxes the strong state overlap assumptions required by fully nonparametric approaches, improving efficiency and stability. To address computational and robustness challenges of minimax nuisance estimation, we develop a novel debiased plug-in estimator based on isotonic Bellman calibration, which integrates fitted Q-iteration with an isotonic regression step. This procedure leverages the Q-function as a data-driven dimension reduction, debiases all linear functionals of interest simultaneously, and enables nonparametric inference without explicit nuisance function estimation. Bellman calibration generalizes isotonic calibration to MDPs and may be of independent interest for prediction in reinforcement learning. Finally, we show that model selection for the Q-function incurs only second-order bias and extend the adaptive debiased machine learning (ADML) framework to MDPs for data-driven learning of semiparametric structure.
Submitted 27 April, 2025; v1 submitted 12 January, 2025;
originally announced January 2025.
-
Towards understanding the bias in decision trees
Authors:
Nathan Phelps,
Daniel J. Lizotte,
Douglas G. Woolford
Abstract:
There is a widespread and longstanding belief that machine learning models are biased towards the majority (or negative) class when learning from imbalanced data, leading them to neglect or ignore the minority (or positive) class. In this study, we show that this belief is not necessarily correct for decision trees, and that their bias can actually be in the opposite direction. Motivated by a recent simulation study that suggested that decision trees can be biased towards the minority class, our paper aims to reconcile the conflict between that study and decades of other works. First, we critically evaluate past literature on this problem, finding that failing to consider the data generating process has led to incorrect conclusions about the bias in decision trees. We then prove that, under specific conditions related to the predictors, decision trees fit to purity and trained on a dataset with only one positive case are biased towards the minority class. Finally, we demonstrate that splits in a decision tree are also biased when there is more than one positive case. Our findings have implications on the use of popular tree-based models, such as random forests.
Submitted 28 February, 2025; v1 submitted 8 January, 2025;
originally announced January 2025.
-
Challenges learning from imbalanced data using tree-based models: Prevalence estimates systematically depend on hyperparameters and can be upwardly biased
Authors:
Nathan Phelps,
Daniel J. Lizotte,
Douglas G. Woolford
Abstract:
Imbalanced binary classification problems arise in many fields of study. When using machine learning models for these problems, it is common to subsample the majority class (i.e., undersampling) to create a (more) balanced dataset for model training. This biases the model's predictions because the model learns from a dataset that does not follow the same data generating process as new data. One way of accounting for this bias is to analytically map the resulting predictions to new values based on the sampling rate for the majority class, which was used to create the training dataset. While this approach may work well for some machine learning models, we have found that calibrating a random forest this way has unintended negative consequences, including prevalence estimates that can be upwardly biased. These prevalence estimates depend on both i) the number of predictors considered at each split in the random forest; and ii) the sampling rate used. We explain the former using known properties of random forests and analytical calibration. However, in investigating the latter issue, we made a surprising discovery - contrary to the widespread belief that decision trees are biased towards the majority class, they actually can be biased towards the minority class.
Submitted 17 December, 2024;
originally announced December 2024.
-
Long-time accuracy of ensemble Kalman filters for chaotic and machine-learned dynamical systems
Authors:
Daniel Sanz-Alonso,
Nathan Waniorek
Abstract:
Filtering is concerned with online estimation of the state of a dynamical system from partial and noisy observations. In applications where the state is high dimensional, ensemble Kalman filters are often the method of choice. This paper establishes long-time accuracy of ensemble Kalman filters. We introduce conditions on the dynamics and the observations under which the estimation error remains small in the long-time horizon. Our theory covers a wide class of partially-observed chaotic dynamical systems, which includes the Navier-Stokes equations and Lorenz models. In addition, we prove long-time accuracy of ensemble Kalman filters with surrogate dynamics, thus validating the use of machine-learned forecast models in ensemble data assimilation.
Submitted 18 December, 2024;
originally announced December 2024.
-
Proximal Iteration for Nonlinear Adaptive Lasso
Authors:
Nathan Wycoff,
Lisa O. Singh,
Ali Arab,
Katharine M. Donato
Abstract:
Augmenting a smooth cost function with an $\ell_1$ penalty allows analysts to efficiently conduct estimation and variable selection simultaneously in sophisticated models and can be efficiently implemented using proximal gradient methods. However, one drawback of the $\ell_1$ penalty is bias: nonzero parameters are underestimated in magnitude, motivating techniques such as the Adaptive Lasso which endow each parameter with its own penalty coefficient. However, it is not clear how these parameter-specific penalties should be set in complex models. In this article, we study the approach of treating the penalty coefficients as additional decision variables to be learned in a Maximum a Posteriori manner, developing a proximal gradient approach to joint optimization of these together with the parameters of any differentiable cost function. Beyond reducing bias in estimates, this procedure can also encourage arbitrary sparsity structure via a prior on the penalty coefficients. We compare our method to implementations of specific sparsity structures for non-Gaussian regression on synthetic and real datasets, finding our more general method to be competitive in terms of both speed and accuracy. We then consider nonlinear models for two case studies: COVID-19 vaccination behavior and international refugee movement, highlighting the applicability of this approach to complex problems and intricate sparsity structures.
Submitted 7 December, 2024;
originally announced December 2024.
-
Provable Tempered Overfitting of Minimal Nets and Typical Nets
Authors:
Itamar Harel,
William M. Hoza,
Gal Vardi,
Itay Evron,
Nathan Srebro,
Daniel Soudry
Abstract:
We study the overfitting behavior of fully connected deep Neural Networks (NNs) with binary weights fitted to perfectly classify a noisy training set. We consider interpolation using both the smallest NN (having the minimal number of weights) and a random interpolating NN. For both learning rules, we prove overfitting is tempered. Our analysis rests on a new bound on the size of a threshold circuit consistent with a partial function. To the best of our knowledge, ours are the first theoretical results on benign or tempered overfitting that: (1) apply to deep NNs, and (2) do not require a very high or very low input dimension.
Submitted 24 October, 2024;
originally announced October 2024.
-
Using Platt's scaling for calibration after undersampling -- limitations and how to address them
Authors:
Nathan Phelps,
Daniel J. Lizotte,
Douglas G. Woolford
Abstract:
When modelling data where the response is dichotomous and highly imbalanced, response-based sampling where a subset of the majority class is retained (i.e., undersampling) is often used to create more balanced training datasets prior to modelling. However, the models fit to this undersampled data, which we refer to as base models, generate predictions that are severely biased. There are several calibration methods that can be used to combat this bias, one of which is Platt's scaling. Here, a logistic regression model is used to model the relationship between the base model's original predictions and the response. Despite its popularity for calibrating models after undersampling, Platt's scaling was not designed for this purpose. Our work presents what we believe is the first detailed study focused on the validity of using Platt's scaling to calibrate models after undersampling. We show analytically, as well as via a simulation study and a case study, that Platt's scaling should not be used for calibration after undersampling without critical thought. If Platt's scaling would have been able to successfully calibrate the base model had it been trained on the entire dataset (i.e., without undersampling), then Platt's scaling might be appropriate for calibration after undersampling. If this is not the case, we recommend a modified version of Platt's scaling that fits a logistic generalized additive model to the logit of the base model's predictions, as it is both theoretically motivated and performed well across the settings considered in our study.
Submitted 4 December, 2024; v1 submitted 22 October, 2024;
originally announced October 2024.
-
Sacred and Profane: from the Involutive Theory of MCMC to Helpful Hamiltonian Hacks
Authors:
Nathan E. Glatt-Holtz,
Andrew J. Holbrook,
Justin A. Krometis,
Cecilia F. Mondaini,
Ami Sheth
Abstract:
In the first edition of this Handbook, two remarkable chapters consider seemingly distinct yet deeply connected subjects ...
Submitted 29 October, 2024; v1 submitted 22 October, 2024;
originally announced October 2024.
-
Reward Maximization for Pure Exploration: Minimax Optimal Good Arm Identification for Nonparametric Multi-Armed Bandits
Authors:
Brian Cho,
Dominik Meier,
Kyra Gan,
Nathan Kallus
Abstract:
In multi-armed bandits, the tasks of reward maximization and pure exploration are often at odds with each other. The former focuses on exploiting arms with the highest means, while the latter may require constant exploration across all arms. In this work, we focus on good arm identification (GAI), a practical bandit inference objective that aims to label arms with means above a threshold as quickly as possible. We show that GAI can be efficiently solved by combining a reward-maximizing sampling algorithm with a novel nonparametric anytime-valid sequential test for labeling arm means. We first establish that our sequential test maintains error control under highly nonparametric assumptions and asymptotically achieves the minimax optimal e-power, a notion of power for anytime-valid tests. Next, by pairing regret-minimizing sampling schemes with our sequential test, we provide an approach that achieves minimax optimal stopping times for labeling arms with means above a threshold, under an error probability constraint. Our empirical results validate our approach beyond the minimax setting, reducing the expected number of samples for all stopping times by at least 50% across both synthetic and real-world settings.
Submitted 20 October, 2024;
originally announced October 2024.
-
Anytime-Valid Continuous-Time Confidence Processes for Inhomogeneous Poisson Processes
Authors:
Michael Lindon,
Nathan Kallus
Abstract:
Motivated by monitoring the arrival of incoming adverse events such as customer support calls or crash reports from users exposed to an experimental product change, we consider sequential hypothesis testing of continuous-time inhomogeneous Poisson point processes. Specifically, we provide an interval-valued confidence process $C^\alpha(t)$ over continuous time $t$ for the cumulative arrival rate $\Lambda(t) = \int_0^t \lambda(s) \mathrm{d}s$ with a continuous-time anytime-valid coverage guarantee $\mathbb{P}[\Lambda(t) \in C^\alpha(t) \, \forall t >0] \geq 1-\alpha$. We extend our results to compare two independent arrival processes by constructing multivariate confidence processes and a closed-form $e$-process for testing the equality of rates with a time-uniform Type-I error guarantee at a nominal $\alpha$. We characterize the asymptotic growth rate of the proposed $e$-process under the alternative and show that it has power 1 when the average rates of the two Poisson processes differ in the limit. We also observe a complementary relationship between our multivariate confidence process and the universal inference $e$-process for testing composite null hypotheses.
Submitted 11 October, 2024;
originally announced October 2024.
-
Adjusting Regression Models for Conditional Uncertainty Calibration
Authors:
Ruijiang Gao,
Mingzhang Yin,
James McInerney,
Nathan Kallus
Abstract:
Conformal Prediction methods have finite-sample distribution-free marginal coverage guarantees. However, they generally do not offer conditional coverage guarantees, which can be important for high-stakes decisions. In this paper, we propose a novel algorithm to train a regression function to improve the conditional coverage after applying the split conformal prediction procedure. We establish an upper bound for the miscoverage gap between the conditional coverage and the nominal coverage rate and propose an end-to-end algorithm to control this upper bound. We demonstrate the efficacy of our method empirically on synthetic and real-world datasets.
Submitted 25 September, 2024;
originally announced September 2024.
-
Physics-informed kernel learning
Authors:
Nathan Doumèche,
Francis Bach,
Gérard Biau,
Claire Boyer
Abstract:
Physics-informed machine learning typically integrates physical priors into the learning process by minimizing a loss function that includes both a data-driven term and a partial differential equation (PDE) regularization. Building on the formulation of the problem as a kernel regression task, we use Fourier methods to approximate the associated kernel, and propose a tractable estimator that minimizes the physics-informed risk function. We refer to this approach as physics-informed kernel learning (PIKL). This framework provides theoretical guarantees, enabling the quantification of the physical prior's impact on convergence speed. We demonstrate the numerical performance of the PIKL estimator through simulations, both in the context of hybrid modeling and in solving PDEs. In particular, we show that PIKL can outperform physics-informed neural networks in terms of both accuracy and computation time. Additionally, we identify cases where PIKL surpasses traditional PDE solvers, particularly in scenarios with noisy boundary conditions.
Submitted 20 September, 2024;
originally announced September 2024.
-
The Central Role of the Loss Function in Reinforcement Learning
Authors:
Kaiwen Wang,
Nathan Kallus,
Wen Sun
Abstract:
This paper illustrates the central role of loss functions in data-driven decision making, providing a comprehensive survey on their influence in cost-sensitive classification (CSC) and reinforcement learning (RL). We demonstrate how different regression loss functions affect the sample efficiency and adaptivity of value-based decision making algorithms. Across multiple settings, we prove that algorithms using the binary cross-entropy loss achieve first-order bounds scaling with the optimal policy's cost and are much more efficient than the commonly used squared loss. Moreover, we prove that distributional algorithms using the maximum likelihood loss achieve second-order bounds scaling with the policy variance and are even sharper than first-order bounds. This in particular proves the benefits of distributional RL. We hope that this paper serves as a guide for analyzing decision making algorithms with varying loss functions, and can inspire the reader to seek out better loss functions to improve any decision making algorithm.
Submitted 4 April, 2025; v1 submitted 19 September, 2024;
originally announced September 2024.
-
Overfitting Behaviour of Gaussian Kernel Ridgeless Regression: Varying Bandwidth or Dimensionality
Authors:
Marko Medvedev,
Gal Vardi,
Nathan Srebro
Abstract:
We consider the overfitting behavior of minimum norm interpolating solutions of Gaussian kernel ridge regression (i.e. kernel ridgeless regression), when the bandwidth or input dimension varies with the sample size. For fixed dimensions, we show that even with varying or tuned bandwidth, the ridgeless solution is never consistent and, at least with large enough noise, always worse than the null predictor. For increasing dimension, we give a generic characterization of the overfitting behavior for any scaling of the dimension with sample size. We use this to provide the first example of benign overfitting using the Gaussian kernel with sub-polynomial scaling dimension. All our results are under the Gaussian universality ansatz and the (non-rigorous) risk predictions in terms of the kernel eigenstructure.
Submitted 5 September, 2024;
originally announced September 2024.
-
CRUD-Capable Mobile Apps with R and shinyMobile: a Case Study in Rapid Prototyping
Authors:
Nathan Henry
Abstract:
"Harden" is a Progressive Web Application (PWA) for Ecological Momentary Assessment (EMA) developed mostly in R, which runs on all platforms with an internet connection, including iOS and Android. It leverages the shinyMobile package for creating a reactive mobile user interface (UI), PostgreSQL for the database backend, and Google Cloud Run for scalable hosting in the cloud, with serverless execution. Using this technology stack, it was possible to rapidly prototype a fully CRUD-capable (Create, Read, Update, Delete) mobile app, with persistent user data across sessions, interactive graphs, and real-time statistical calculation. This framework is compared with current alternative frameworks for creating data science apps; it is argued that the shinyMobile package provides one of the most efficient methods for rapid prototyping and creation of statistical mobile apps that require advanced graphing capabilities. This paper outlines the methodology used to create the Harden application, and discusses the advantages and limitations of the shinyMobile approach to app development. It is hoped that this information will encourage other programmers versed in R to consider developing mobile apps with this framework.
Submitted 31 August, 2024;
originally announced September 2024.
-
CSPI-MT: Calibrated Safe Policy Improvement with Multiple Testing for Threshold Policies
Authors:
Brian M Cho,
Ana-Roxana Pop,
Kyra Gan,
Sam Corbett-Davies,
Israel Nir,
Ariel Evnine,
Nathan Kallus
Abstract:
When modifying existing policies in high-risk settings, it is often necessary to ensure with high certainty that the newly proposed policy improves upon a baseline, such as the status quo. In this work, we consider the problem of safe policy improvement, where one only adopts a new policy if it is deemed to be better than the specified baseline with at least pre-specified probability. We focus on threshold policies, a ubiquitous class of policies with applications in economics, healthcare, and digital advertising. Existing methods rely on potentially underpowered safety checks and limit the opportunities for finding safe improvements, so too often they must revert to the baseline to maintain safety. We overcome these issues by leveraging the most powerful safety test in the asymptotic regime and allowing for multiple candidates to be tested for improvement over the baseline. We show that in adversarial settings, our approach controls the rate of adopting a policy worse than the baseline to the pre-specified error level, even in moderate sample sizes. We present CSPI and CSPI-MT, two novel heuristics for selecting cutoff(s) to maximize the policy improvement from baseline. We demonstrate through both synthetic and external datasets that our approaches improve both the detection rates of safe policies and the realized improvement, particularly under stringent safety requirements and low signal-to-noise conditions.
Submitted 21 August, 2024;
originally announced August 2024.
-
Bayesian Causal Forests for Longitudinal Data: Assessing the Impact of Part-Time Work on Growth in High School Mathematics Achievement
Authors:
Nathan McJames,
Ann O'Shea,
Andrew Parnell
Abstract:
Modelling growth in student achievement is a significant challenge in the field of education. Understanding how interventions or experiences such as part-time work can influence this growth is also important. Traditional methods like difference-in-differences are effective for estimating causal effects from longitudinal data. Meanwhile, Bayesian non-parametric methods have recently become popular for estimating causal effects from single time point observational studies. However, there remains a scarcity of methods capable of combining the strengths of these two approaches to flexibly estimate heterogeneous causal effects from longitudinal data. Motivated by two waves of data from the High School Longitudinal Study, the NCES' most recent longitudinal study which tracks a representative sample of over 20,000 students in the US, our study introduces a longitudinal extension of Bayesian Causal Forests. This model allows for the flexible identification of both individual growth in mathematical ability and the effects of participation in part-time work. Simulation studies demonstrate the predictive performance and reliable uncertainty quantification of the proposed model. Results reveal the negative impact of part-time work for most students, but hint at potential benefits for those students with an initially low sense of school belonging. Clear signs of a widening achievement gap between students with high and low academic achievement are also identified. Potential policy implications are discussed, along with promising areas for future research.
Submitted 16 July, 2024;
originally announced July 2024.
-
Estimating Heterogeneous Treatment Effects by Combining Weak Instruments and Observational Data
Authors:
Miruna Oprescu,
Nathan Kallus
Abstract:
Accurately predicting conditional average treatment effects (CATEs) is crucial in personalized medicine and digital platform analytics. Since the treatments of interest often cannot be directly randomized, observational data is leveraged to learn CATEs, but this approach can incur significant bias from unobserved confounding. One strategy to overcome these limitations is to leverage instrumental variables (IVs) as latent quasi-experiments, such as randomized intent-to-treat assignments or randomized product recommendations. This approach, on the other hand, can suffer from low compliance, $\textit{i.e.}$, IV weakness. Some subgroups may even exhibit zero compliance, meaning we cannot instrument for their CATEs at all. In this paper, we develop a novel approach to combine IV and observational data to enable reliable CATE estimation in the presence of unobserved confounding in the observational data and low compliance in the IV data, including no compliance for some subgroups. We propose a two-stage framework that first learns $\textit{biased}$ CATEs from the observational data, and then applies a compliance-weighted correction using IV data, effectively leveraging IV strength variability across covariates. We characterize the convergence rates of our method and validate its effectiveness through a simulation study. Additionally, we demonstrate its utility with real data by analyzing the heterogeneous effects of 401(k) plan participation on wealth.
Submitted 1 November, 2024; v1 submitted 10 June, 2024;
originally announced June 2024.
-
The Price of Implicit Bias in Adversarially Robust Generalization
Authors:
Nikolaos Tsilivis,
Natalie Frank,
Nathan Srebro,
Julia Kempe
Abstract:
We study the implicit bias of optimization in robust empirical risk minimization (robust ERM) and its connection with robust generalization. In classification settings under adversarial perturbations with linear models, we study what type of regularization should ideally be applied for a given perturbation set to improve (robust) generalization. We then show that the implicit bias of optimization in robust ERM can significantly affect the robustness of the model and identify two ways this can happen: through the optimization algorithm or through the architecture. We verify our predictions in simulations with synthetic data and experimentally study the importance of implicit bias in robust ERM with deep neural networks.
Submitted 7 June, 2024;
originally announced June 2024.
-
Aligning Multiclass Neural Network Classifier Criterion with Task Performance Metrics
Authors:
Deyuan Li,
Taesoo Daniel Lee,
Marynel Vázquez,
Nathan Tsoi
Abstract:
Multiclass neural network classifiers are typically trained using cross-entropy loss but evaluated using metrics derived from the confusion matrix, such as Accuracy, $F_β$-Score, and Matthews Correlation Coefficient. This mismatch between the training objective and evaluation metric can lead to suboptimal performance, particularly when the user's priorities differ from what cross-entropy implicitly optimizes. For example, in the presence of class imbalance, $F_1$-Score may be preferred over Accuracy. Similarly, given a preference towards precision, the $F_{β=0.25}$-Score will better reflect this preference than $F_1$-Score. However, standard cross-entropy loss does not accommodate such a preference. Building on prior work leveraging soft-set confusion matrices and a continuous piecewise-linear Heaviside approximation, we propose Evaluation Aligned Surrogate Training (EAST), a novel approach to train multiclass classifiers using close surrogates of confusion-matrix based metrics, thereby aligning a neural network classifier's predictions more closely to a target evaluation metric than typical cross-entropy loss. EAST introduces three key innovations: First, we propose a novel dynamic thresholding approach during training. Second, we propose using a multiclass soft-set confusion matrix. Third, we introduce an annealing process that gradually aligns the surrogate loss with the target evaluation metric. Our theoretical analysis shows that EAST results in consistent estimators of the target evaluation metric. Furthermore, we show that the learned network parameters converge asymptotically to values that optimize for the target evaluation metric. Extensive experiments validate the effectiveness of our approach, demonstrating improved alignment between training objectives and evaluation metrics, while outperforming existing methods across many datasets.
Submitted 26 May, 2025; v1 submitted 31 May, 2024;
originally announced May 2024.
-
Enhancing Generative Molecular Design via Uncertainty-guided Fine-tuning of Variational Autoencoders
Authors:
A N M Nafiz Abeer,
Sanket Jantre,
Nathan M Urban,
Byung-Jun Yoon
Abstract:
In recent years, deep generative models have been successfully adopted for various molecular design tasks, particularly in the life and material sciences. A critical challenge for pre-trained generative molecular design (GMD) models is to fine-tune them to be better suited for downstream design tasks aimed at optimizing specific molecular properties. However, redesigning and training an existing effective generative model from scratch for each new design task is impractical. Furthermore, the black-box nature of typical downstream tasks -- such as property prediction -- makes it nontrivial to optimize the generative model in a task-specific manner. In this work, we propose a novel approach for a model uncertainty-guided fine-tuning of a pre-trained variational autoencoder (VAE)-based GMD model through performance feedback in an active learning setting. The main idea is to quantify model uncertainty in the generative model, which is made efficient by working within a low-dimensional active subspace of the high-dimensional VAE parameters explaining most of the variability in the model's output. The inclusion of model uncertainty expands the space of viable molecules through decoder diversity. We then explore the resulting model uncertainty class via black-box optimization made tractable by low-dimensionality of the active subspace. This enables us to identify and leverage a diverse set of high-performing models to generate enhanced molecules. Empirical results across six target molecular properties, using multiple VAE-based generative models, demonstrate that our uncertainty-guided fine-tuning approach consistently outperforms the original pre-trained models.
Submitted 30 May, 2024;
originally announced May 2024.