-
Emergence and scaling laws in SGD learning of shallow neural networks
Authors:
Yunwei Ren,
Eshaan Nichani,
Denny Wu,
Jason D. Lee
Abstract:
We study the complexity of online stochastic gradient descent (SGD) for learning a two-layer neural network with $P$ neurons on isotropic Gaussian data: $f_*(\boldsymbol{x}) = \sum_{p=1}^P a_p\cdot σ(\langle\boldsymbol{x},\boldsymbol{v}_p^*\rangle)$, $\boldsymbol{x} \sim \mathcal{N}(0,\boldsymbol{I}_d)$, where the activation $σ:\mathbb{R}\to\mathbb{R}$ is an even function with information exponent $k_*>2$ (defined as the lowest degree in the Hermite expansion), $\{\boldsymbol{v}^*_p\}_{p\in[P]}\subset \mathbb{R}^d$ are orthonormal signal directions, and the non-negative second-layer coefficients satisfy $\sum_{p} a_p^2=1$. We focus on the challenging "extensive-width" regime $P\gg 1$ and permit a diverging condition number in the second layer, covering as a special case the power-law scaling $a_p\asymp p^{-β}$ where $β\in\mathbb{R}_{\ge 0}$. We provide a precise analysis of SGD dynamics for the training of a student two-layer network to minimize the mean squared error (MSE) objective, and explicitly identify sharp transition times to recover each signal direction. In the power-law setting, we characterize scaling law exponents for the MSE loss with respect to the number of training samples and SGD steps, as well as the number of parameters in the student neural network. Our analysis entails that while the learning of individual teacher neurons exhibits abrupt transitions, the juxtaposition of $P\gg 1$ emergent learning curves at different timescales leads to a smooth scaling law in the cumulative objective.
Submitted 28 April, 2025;
originally announced April 2025.
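As a rough illustration of the teacher model in the abstract above, the following sketch builds $f_*$ with orthonormal directions and power-law coefficients normalized so that $\sum_p a_p^2 = 1$. The choice $σ = \mathrm{He}_4$ (an even function whose lowest Hermite degree is 4), the helper names `make_teacher` and `teacher`, and the use of a QR factorization to obtain orthonormal directions are assumptions made for illustration, not details taken from the paper.

```python
import numpy as np

def he4(z):
    # Probabilists' Hermite polynomial He_4: even, lowest Hermite degree 4 (> 2).
    return z**4 - 6.0 * z**2 + 3.0

def make_teacher(d, P, beta, seed=0):
    """Return (a, V): power-law coefficients and orthonormal directions (hypothetical helper)."""
    rng = np.random.default_rng(seed)
    # Reduced QR of a d x P Gaussian matrix gives P orthonormal columns in R^d.
    Q, _ = np.linalg.qr(rng.standard_normal((d, P)))
    V = Q.T                                   # rows are the directions v_p^*
    a = np.arange(1, P + 1, dtype=float) ** (-beta)
    a /= np.linalg.norm(a)                    # enforce sum_p a_p^2 = 1
    return a, V

def teacher(x, a, V, sigma=he4):
    # f_*(x) = sum_p a_p * sigma(<x, v_p^*>), evaluated for each row of x.
    return sigma(x @ V.T) @ a

a, V = make_teacher(d=64, P=8, beta=1.0)
x = np.random.default_rng(1).standard_normal((5, 64))  # isotropic Gaussian inputs
y = teacher(x, a, V)
```

Setting `beta=0.0` gives equal coefficients, while larger `beta` spreads the $a_p$ over a wider range, which is the regime where the abstract describes learning curves at different timescales.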
-
What Makes a Reward Model a Good Teacher? An Optimization Perspective
Authors:
Noam Razin,
Zixuan Wang,
Hubert Strauss,
Stanley Wei,
Jason D. Lee,
Sanjeev Arora
Abstract:
The success of Reinforcement Learning from Human Feedback (RLHF) critically depends on the quality of the reward model. While this quality is primarily evaluated through accuracy, it remains unclear whether accuracy fully captures what makes a reward model an effective teacher. We address this question from an optimization perspective. First, we prove that regardless of how accurate a reward model is, if it induces low reward variance, then the RLHF objective suffers from a flat landscape. Consequently, even a perfectly accurate reward model can lead to extremely slow optimization, underperforming less accurate models that induce higher reward variance. We additionally show that a reward model that works well for one language model can induce low reward variance, and thus a flat objective landscape, for another. These results establish a fundamental limitation of evaluating reward models solely based on accuracy or independently of the language model they guide. Experiments using models of up to 8B parameters corroborate our theory, demonstrating the interplay between reward variance, accuracy, and reward maximization rate. Overall, our findings highlight that beyond accuracy, a reward model needs to induce sufficient variance for efficient optimization.
Submitted 19 March, 2025;
originally announced March 2025.
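The link between reward variance and a flat objective can be seen in a toy setting. The sketch below assumes a tabular softmax policy over a handful of outputs, which is much simpler than the language-model policies the abstract refers to; the function names are hypothetical. In this toy case the gradient of the expected reward with respect to the logits is $π_i (r_i - \bar r)$, so its norm is at most the square root of the reward variance.

```python
import numpy as np

def softmax(theta):
    z = np.exp(theta - np.max(theta))
    return z / z.sum()

def reward_variance(theta, r):
    # Variance of the reward r(y) when y is drawn from the softmax policy.
    pi = softmax(theta)
    mean = pi @ r
    return pi @ (r - mean) ** 2

def expected_reward_grad(theta, r):
    # Gradient of E_{y ~ pi_theta}[r(y)] with respect to the logits theta.
    pi = softmax(theta)
    return pi * (r - pi @ r)

theta = np.zeros(4)                           # uniform initial policy (assumption)
r_sharp = np.array([1.0, 0.0, 0.0, 0.0])      # ranks output 0 first, high variance
r_flat = np.array([0.51, 0.5, 0.5, 0.5])      # same top-ranked output, low variance

g_sharp = np.linalg.norm(expected_reward_grad(theta, r_sharp))
g_flat = np.linalg.norm(expected_reward_grad(theta, r_flat))
```

Both toy reward vectors agree on which output is best, yet `g_flat` is far smaller than `g_sharp`, mirroring the abstract's point that accuracy alone does not determine how quickly the reward can be maximized.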
-
Discrepancies are Virtue: Weak-to-Strong Generalization through Lens of Intrinsic Dimension
Authors:
Yijun Dong,
Yicheng Li,
Yunai Li,
Jason D. Lee,
Qi Lei
Abstract:
Weak-to-strong (W2S) generalization is a type of finetuning (FT) where a strong (large) student model is trained on pseudo-labels generated by a weak teacher. Surprisingly, W2S FT often outperforms the weak teacher. We seek to understand this phenomenon through the observation that FT often occurs in intrinsically low-dimensional spaces. Leveraging the low intrinsic dimensionality of FT, we analyze W2S in the ridgeless regression setting from a variance reduction perspective. For a strong student-weak teacher pair with sufficiently expressive low-dimensional feature subspaces $\mathcal{V}_s, \mathcal{V}_w$, we provide an exact characterization of the variance that dominates the generalization error of W2S. This unveils a virtue of discrepancy between the strong and weak models in W2S: the variance of the weak teacher is inherited by the strong student in $\mathcal{V}_s \cap \mathcal{V}_w$, while reduced by a factor of $\dim(\mathcal{V}_s)/N$ in the subspace of discrepancy $\mathcal{V}_w \setminus \mathcal{V}_s$ with $N$ pseudo-labels for W2S. Our analysis further casts light on the sample complexities and the scaling of performance gap recovery in W2S. The analysis is supported by experiments on synthetic regression problems, as well as real vision and NLP tasks.
Submitted 22 May, 2025; v1 submitted 7 February, 2025;
originally announced February 2025.
-
Understanding Factual Recall in Transformers via Associative Memories
Authors:
Eshaan Nichani,
Jason D. Lee,
Alberto Bietti
Abstract:
Large language models have demonstrated an impressive ability to perform factual recall. Prior work has found that transformers trained on factual recall tasks can store information at a rate proportional to their parameter count. In our work, we show that shallow transformers can use a combination of associative memories to obtain such near optimal storage capacity. We begin by proving that the storage capacities of both linear and MLP associative memories scale linearly with parameter count. We next introduce a synthetic factual recall task, and prove that a transformer with a single layer of self-attention followed by an MLP can obtain 100% accuracy on the task whenever either the total number of self-attention parameters or MLP parameters scales (up to log factors) linearly with the number of facts. In particular, the transformer can trade off between using the value matrices or the MLP as an associative memory to store the dataset of facts. We complement these expressivity results with an analysis of the gradient flow trajectory of a simplified linear attention model trained on our factual recall task, where we show that the model exhibits sequential learning behavior.
Submitted 9 December, 2024;
originally announced December 2024.
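For readers unfamiliar with linear associative memories, a minimal sketch follows. It assumes random unit-norm embeddings and the classical outer-product construction $W = \sum_i v_{\text{ans}(i)} u_i^\top$; the paper's exact construction and capacity analysis may differ, and all function names here are hypothetical.

```python
import numpy as np

def random_unit_vectors(n, d, rng):
    X = rng.standard_normal((n, d))
    return X / np.linalg.norm(X, axis=1, keepdims=True)

def build_memory(U, V, answers):
    # W = sum_i v_{answers[i]} u_i^T, a d x d matrix storing all facts at once.
    return V[answers].T @ U

def recall(W, U, V):
    # For each key u_i, return the answer index j maximizing <v_j, W u_i>.
    return np.argmax(V @ W @ U.T, axis=0)

rng = np.random.default_rng(0)
d, num_facts, num_answers = 512, 32, 16       # illustrative sizes (assumption)
U = random_unit_vectors(num_facts, d, rng)    # key embeddings, one per fact
V = random_unit_vectors(num_answers, d, rng)  # answer embeddings
answers = rng.integers(0, num_answers, size=num_facts)

W = build_memory(U, V, answers)
recalled = recall(W, U, V)
```

Recall works here because random high-dimensional embeddings are nearly orthogonal, so the cross-talk between stored facts stays small relative to the signal term.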
-
Anytime Acceleration of Gradient Descent
Authors:
Zihan Zhang,
Jason D. Lee,
Simon S. Du,
Yuxin Chen
Abstract:
This work investigates stepsize-based acceleration of gradient descent with {\em anytime} convergence guarantees. For smooth (non-strongly) convex optimization, we propose a stepsize schedule that allows gradient descent to achieve convergence guarantees of $O(T^{-1.119})$ for any stopping time $T$, where the stepsize schedule is predetermined without prior knowledge of the stopping time. This result provides an affirmative answer to a COLT open problem \citep{kornowski2024open} regarding whether stepsize-based acceleration can yield anytime convergence rates of $o(T^{-1})$. We further extend our theory to yield anytime convergence guarantees of $\exp(-Ω(T/κ^{0.893}))$ for smooth and strongly convex optimization, with $κ$ being the condition number.
Submitted 8 December, 2024; v1 submitted 26 November, 2024;
originally announced November 2024.
-
Learning Hierarchical Polynomials of Multiple Nonlinear Features with Three-Layer Networks
Authors:
Hengyu Fu,
Zihao Wang,
Eshaan Nichani,
Jason D. Lee
Abstract:
In deep learning theory, a critical question is to understand how neural networks learn hierarchical features. In this work, we study the learning of hierarchical polynomials of \textit{multiple nonlinear features} using three-layer neural networks. We examine a broad class of functions of the form $f^{\star}=g^{\star}\circ \mathbf{p}$, where $\mathbf{p}:\mathbb{R}^{d} \rightarrow \mathbb{R}^{r}$ represents multiple quadratic features with $r \ll d$ and $g^{\star}:\mathbb{R}^{r}\rightarrow \mathbb{R}$ is a polynomial of degree $p$. This can be viewed as a nonlinear generalization of the multi-index model \citep{damian2022neural}, and also an expansion upon previous work that focused only on a single nonlinear feature, i.e. $r = 1$ \citep{nichani2023provable,wang2023learning}.
Our primary contribution shows that a three-layer neural network trained via layerwise gradient descent suffices for (i) complete recovery of the space spanned by the nonlinear features and (ii) efficient learning of the target function $f^{\star}=g^{\star}\circ \mathbf{p}$, or transfer learning of $f=g\circ \mathbf{p}$ with a different link function, within $\widetilde{\mathcal{O}}(d^4)$ samples and polynomial time. For such hierarchical targets, our result substantially improves on the $Θ(d^{2p})$ sample complexity of kernel methods, demonstrating the power of efficient feature learning. Our results leverage novel techniques and thus go beyond all prior settings, such as single-index and multi-index models as well as models depending on just one nonlinear feature, contributing to a more comprehensive understanding of feature learning in deep learning.
Submitted 26 November, 2024;
originally announced November 2024.
-
Understanding Optimization in Deep Learning with Central Flows
Authors:
Jeremy M. Cohen,
Alex Damian,
Ameet Talwalkar,
Zico Kolter,
Jason D. Lee
Abstract:
Optimization in deep learning remains poorly understood, even in the simple setting of deterministic (i.e. full-batch) training. A key difficulty is that much of an optimizer's behavior is implicitly determined by complex oscillatory dynamics, referred to as the "edge of stability." The main contribution of this paper is to show that an optimizer's implicit behavior can be explicitly captured by a "central flow:" a differential equation which models the time-averaged optimization trajectory. We show that these flows can empirically predict long-term optimization trajectories of generic neural networks with a high degree of numerical accuracy. By interpreting these flows, we reveal for the first time 1) the precise sense in which RMSProp adapts to the local loss landscape, and 2) an "acceleration via regularization" mechanism, wherein adaptive optimizers implicitly navigate towards low-curvature regions in which they can take larger steps. This mechanism is key to the efficacy of these adaptive optimizers. Overall, we believe that central flows constitute a promising tool for reasoning about optimization in deep learning.
Submitted 31 October, 2024;
originally announced October 2024.
-
Learning Orthogonal Multi-Index Models: A Fine-Grained Information Exponent Analysis
Authors:
Yunwei Ren,
Jason D. Lee
Abstract:
The information exponent (Ben Arous et al. [2021]) -- which is equivalent to the lowest degree in the Hermite expansion of the link function for Gaussian single-index models -- has played an important role in predicting the sample complexity of online stochastic gradient descent (SGD) in various learning tasks. In this work, we demonstrate that, for multi-index models, focusing solely on the lowest degree can miss key structural details of the model and result in suboptimal rates.
Specifically, we consider the task of learning target functions of the form $f_*(\mathbf{x}) = \sum_{k=1}^{P} φ(\mathbf{v}_k^* \cdot \mathbf{x})$, where $P \ll d$, the ground-truth directions $\{ \mathbf{v}_k^* \}_{k=1}^P$ are orthonormal, and only the second and $2L$-th Hermite coefficients of the link function $φ$ can be nonzero. Based on the theory of information exponent, when the lowest degree is $2L$, recovering the directions requires $d^{2L-1}\mathrm{poly}(P)$ samples, and when the lowest degree is $2$, only the relevant subspace (not the exact directions) can be recovered due to the rotational invariance of the second-order terms. In contrast, we show that by considering both second- and higher-order terms, we can first learn the relevant space via the second-order terms, and then the exact directions using the higher-order terms, and the overall sample complexity of online SGD is $d \mathrm{poly}(P)$.
Submitted 12 October, 2024;
originally announced October 2024.
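The "lowest degree in the Hermite expansion" mentioned in the abstract above can be checked numerically for a given link function. The sketch below is an illustration rather than anything from the paper: it estimates the probabilists' Hermite coefficients $c_k = \mathbb{E}[φ(Z)\,\mathrm{He}_k(Z)]/k!$ by Gauss-Hermite quadrature and returns the smallest $k \ge 1$ with a nonzero coefficient. The function names, the degree cap, and the tolerance are assumptions.

```python
import math
import numpy as np
from numpy.polynomial.hermite_e import hermegauss, hermeval

def hermite_coefficients(f, max_degree=8, quad_points=40):
    # Nodes/weights for the weight exp(-x^2/2); rescale so weights sum to 1.
    x, w = hermegauss(quad_points)
    w = w / math.sqrt(2.0 * math.pi)
    fx = f(x)
    coeffs = []
    for k in range(max_degree + 1):
        basis = np.zeros(k + 1)
        basis[k] = 1.0                        # selects He_k in hermeval
        he_k = hermeval(x, basis)
        coeffs.append(np.sum(w * fx * he_k) / math.factorial(k))
    return np.array(coeffs)

def information_exponent(f, max_degree=8, tol=1e-8):
    # Smallest k >= 1 whose Hermite coefficient is non-negligible, else None.
    c = hermite_coefficients(f, max_degree)
    for k in range(1, max_degree + 1):
        if abs(c[k]) > tol:
            return k
    return None

he2 = lambda z: z**2 - 1.0
he4 = lambda z: z**4 - 6.0 * z**2 + 3.0
k_pure = information_exponent(he4)                           # only the degree-4 term
k_mixed = information_exponent(lambda z: he2(z) + he4(z))    # degree-2 term present
```

The second link function matches the structure discussed in the abstract, where both a second-order and a higher-order Hermite term are present and the lowest degree alone does not capture the full picture.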
-
Scaling Laws in Linear Regression: Compute, Parameters, and Data
Authors:
Licong Lin,
Jingfeng Wu,
Sham M. Kakade,
Peter L. Bartlett,
Jason D. Lee
Abstract:
Empirically, large-scale deep learning models often satisfy a neural scaling law: the test error of the trained model improves polynomially as the model size and data size grow. However, conventional wisdom suggests the test error consists of approximation, bias, and variance errors, where the variance error increases with model size. This disagrees with the general form of neural scaling laws, which predict that increasing model size monotonically improves performance.
We study the theory of scaling laws in an infinite dimensional linear regression setup. Specifically, we consider a model with $M$ parameters as a linear function of sketched covariates. The model is trained by one-pass stochastic gradient descent (SGD) using $N$ data. Assuming the optimal parameter satisfies a Gaussian prior and the data covariance matrix has a power-law spectrum of degree $a>1$, we show that the reducible part of the test error is $Θ(M^{-(a-1)} + N^{-(a-1)/a})$. The variance error, which increases with $M$, is dominated by the other errors due to the implicit regularization of SGD, thus disappearing from the bound. Our theory is consistent with the empirical neural scaling laws and verified by numerical simulation.
Submitted 29 October, 2024; v1 submitted 12 June, 2024;
originally announced June 2024.
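To make the rate $Θ(M^{-(a-1)} + N^{-(a-1)/a})$ concrete, the sketch below evaluates the two terms with all hidden constants set to 1 (an assumption, since $Θ(\cdot)$ hides them) and computes the model size at which they balance. Equating $M^{-(a-1)} = N^{-(a-1)/a}$ gives $M = N^{1/a}$. The function names are hypothetical.

```python
def reducible_error_rate(M, N, a):
    # Sum of the approximation-type term and the data-limited term,
    # with all constants hidden by Theta(.) set to 1 (assumption).
    if a <= 1:
        raise ValueError("the abstract assumes a power-law degree a > 1")
    return M ** (-(a - 1)) + N ** (-(a - 1) / a)

def balanced_model_size(N, a):
    # Solving M^{-(a-1)} = N^{-(a-1)/a} for M gives M = N^{1/a}.
    return N ** (1.0 / a)

N, a = 10_000, 2.0
M_star = balanced_model_size(N, a)            # about 100 for these values
rate = reducible_error_rate(M_star, N, a)     # the two terms contribute equally
```

Beyond `M_star`, the data term dominates in this simplified reading, so further increases in model size change the rate only through constants.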
-
Transformers Provably Learn Sparse Token Selection While Fully-Connected Nets Cannot
Authors:
Zixuan Wang,
Stanley Wei,
Daniel Hsu,
Jason D. Lee
Abstract:
The transformer architecture has prevailed in various deep learning settings due to its exceptional capabilities to select and compose structural information. Motivated by these capabilities, Sanford et al. proposed the sparse token selection task, in which transformers excel while fully-connected networks (FCNs) fail in the worst case. Building upon that, we strengthen the FCN lower bound to an average-case setting and establish an algorithmic separation of transformers over FCNs. Specifically, a one-layer transformer trained with gradient descent provably learns the sparse token selection task and, surprisingly, exhibits strong out-of-distribution length generalization. We provide empirical simulations to justify our theoretical findings.
Submitted 10 June, 2024;
originally announced June 2024.
-
Neural network learns low-dimensional polynomials with SGD near the information-theoretic limit
Authors:
Jason D. Lee,
Kazusato Oko,
Taiji Suzuki,
Denny Wu
Abstract:
We study the problem of gradient descent learning of a single-index target function $f_*(\boldsymbol{x}) = σ_*\left(\langle\boldsymbol{x},\boldsymbol{θ}\rangle\right)$ under isotropic Gaussian data in $\mathbb{R}^d$, where the unknown link function $σ_*:\mathbb{R}\to\mathbb{R}$ has information exponent $p$ (defined as the lowest degree in the Hermite expansion). Prior works showed that gradient-based training of neural networks can learn this target with $n\gtrsim d^{Θ(p)}$ samples, and such complexity is predicted to be necessary by the correlational statistical query lower bound. Surprisingly, we prove that a two-layer neural network optimized by an SGD-based algorithm (on the squared loss) learns $f_*$ with a complexity that is not governed by the information exponent. Specifically, for arbitrary polynomial single-index models, we establish a sample and runtime complexity of $n \simeq T = Θ(d\!\cdot\! \mathrm{polylog} d)$, where $Θ(\cdot)$ hides a constant only depending on the degree of $σ_*$; this dimension dependence matches the information theoretic limit up to polylogarithmic factors. More generally, we show that $n\gtrsim d^{(p_*-1)\vee 1}$ samples are sufficient to achieve low generalization error, where $p_* \le p$ is the \textit{generative exponent} of the link function. Core to our analysis is the reuse of minibatch in the gradient computation, which gives rise to higher-order information beyond correlational queries.
Submitted 22 December, 2024; v1 submitted 3 June, 2024;
originally announced June 2024.
-
Computational-Statistical Gaps in Gaussian Single-Index Models
Authors:
Alex Damian,
Loucas Pillaud-Vivien,
Jason D. Lee,
Joan Bruna
Abstract:
Single-Index Models are high-dimensional regression problems with planted structure, whereby labels depend on an unknown one-dimensional projection of the input via a generic, non-linear, and potentially non-deterministic transformation. As such, they encompass a broad class of statistical inference tasks, and provide a rich template to study statistical and computational trade-offs in the high-dimensional regime.
While the information-theoretic sample complexity to recover the hidden direction is linear in the dimension $d$, we show that computationally efficient algorithms, both within the Statistical Query (SQ) and the Low-Degree Polynomial (LDP) framework, necessarily require $Ω(d^{k^\star/2})$ samples, where $k^\star$ is a "generative" exponent associated with the model that we explicitly characterize. Moreover, we show that this sample complexity is also sufficient, by establishing matching upper bounds using a partial-trace algorithm. Therefore, our results provide evidence of a sharp computational-to-statistical gap (under both the SQ and LDP class) whenever $k^\star>2$. To complete the study, we provide examples of smooth and Lipschitz deterministic target functions with arbitrarily large generative exponents $k^\star$.
Submitted 12 March, 2024; v1 submitted 8 March, 2024;
originally announced March 2024.
-
How Well Can Transformers Emulate In-context Newton's Method?
Authors:
Angeliki Giannou,
Liu Yang,
Tianhao Wang,
Dimitris Papailiopoulos,
Jason D. Lee
Abstract:
Transformer-based models have demonstrated remarkable in-context learning capabilities, prompting extensive research into their underlying mechanisms. Recent studies have suggested that Transformers can implement first-order optimization algorithms for in-context learning and even second-order ones for the case of linear regression. In this work, we study whether Transformers can perform higher-order optimization methods, beyond the case of linear regression. We establish that linear attention Transformers with ReLU layers can approximate second-order optimization algorithms for the task of logistic regression and achieve $ε$ error with a number of additional layers that is only logarithmic in the error. As a by-product we demonstrate the ability of even linear attention-only Transformers to implement a single step of Newton's iteration for matrix inversion with merely two layers. These results suggest the ability of the Transformer architecture to implement complex algorithms, beyond gradient descent.
Submitted 5 March, 2024;
originally announced March 2024.
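For reference, the classical Newton iteration for matrix inversion (often called Newton-Schulz) updates $X_{k+1} = X_k(2I - AX_k)$, which squares the residual $I - AX_k$ at every step. The sketch below is a plain NumPy version under the assumption that this is the iteration the abstract refers to; it says nothing about how a Transformer would encode the step, and the initialization $X_0 = A^\top / (\|A\|_1 \|A\|_\infty)$ is a standard choice rather than one taken from the paper.

```python
import numpy as np

def newton_schulz_inverse(A, num_iters=20):
    """Approximate A^{-1} for a nonsingular square matrix A (illustrative sketch)."""
    A = np.asarray(A, dtype=float)
    n = A.shape[0]
    # Standard initialization that makes the initial residual I - A X_0 contractive.
    X = A.T / (np.linalg.norm(A, 1) * np.linalg.norm(A, np.inf))
    two_I = 2.0 * np.eye(n)
    for _ in range(num_iters):
        # One Newton step: the residual I - A X is squared by this update.
        X = X @ (two_I - A @ X)
    return X

A = np.array([[4.0, 1.0],
              [2.0, 3.0]])
X = newton_schulz_inverse(A)
```

Because the residual is squared each step, the number of iterations needed for a target accuracy grows only logarithmically in the number of correct digits once the residual is below one, which is the sense in which second-order methods need few steps.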
-
How Transformers Learn Causal Structure with Gradient Descent
Authors:
Eshaan Nichani,
Alex Damian,
Jason D. Lee
Abstract:
The incredible success of transformers on sequence modeling tasks can be largely attributed to the self-attention mechanism, which allows information to be transferred between different parts of a sequence. Self-attention allows transformers to encode causal structure which makes them particularly suitable for sequence modeling. However, the process by which transformers learn such causal structure via gradient-based training algorithms remains poorly understood. To better understand this process, we introduce an in-context learning task that requires learning latent causal structure. We prove that gradient descent on a simplified two-layer transformer learns to solve this task by encoding the latent causal graph in the first attention layer. The key insight of our proof is that the gradient of the attention matrix encodes the mutual information between tokens. As a consequence of the data processing inequality, the largest entries of this gradient correspond to edges in the latent causal graph. As a special case, when the sequences are generated from in-context Markov chains, we prove that transformers learn an induction head (Olsson et al., 2022). We confirm our theoretical findings by showing that transformers trained on our in-context learning task are able to recover a wide variety of causal structures.
Submitted 13 August, 2024; v1 submitted 22 February, 2024;
originally announced February 2024.
-
Towards Optimal Statistical Watermarking
Authors:
Baihe Huang,
Hanlin Zhu,
Banghua Zhu,
Kannan Ramchandran,
Michael I. Jordan,
Jason D. Lee,
Jiantao Jiao
Abstract:
We study statistical watermarking by formulating it as a hypothesis testing problem, a general framework which subsumes all previous statistical watermarking methods. Key to our formulation is a coupling of the output tokens and the rejection region, realized by pseudo-random generators in practice, that allows non-trivial trade-offs between the Type I error and Type II error. We characterize the Uniformly Most Powerful (UMP) watermark in the general hypothesis testing setting and the minimax Type II error in the model-agnostic setting. In the common scenario where the output is a sequence of $n$ tokens, we establish nearly matching upper and lower bounds on the number of i.i.d. tokens required to guarantee small Type I and Type II errors. Our rate of $Θ(h^{-1} \log (1/h))$ with respect to the average entropy per token $h$ highlights the potential for improvement over the rate of $h^{-2}$ in previous works. Moreover, we formulate the robust watermarking problem where the user is allowed to perform a class of perturbations on the generated texts, and characterize the optimal Type II error of robust UMP tests via a linear programming problem. To the best of our knowledge, this is the first systematic statistical treatment of the watermarking problem with near-optimal rates in the i.i.d. setting, which might be of interest for future works.
Submitted 6 February, 2024; v1 submitted 13 December, 2023;
originally announced December 2023.
-
Optimal Multi-Distribution Learning
Authors:
Zihan Zhang,
Wenhao Zhan,
Yuxin Chen,
Simon S. Du,
Jason D. Lee
Abstract:
Multi-distribution learning (MDL), which seeks to learn a shared model that minimizes the worst-case risk across $k$ distinct data distributions, has emerged as a unified framework in response to the evolving demand for robustness, fairness, multi-group collaboration, etc. Achieving data-efficient MDL necessitates adaptive sampling, also called on-demand sampling, throughout the learning process. However, there exist substantial gaps between the state-of-the-art upper and lower bounds on the optimal sample complexity. Focusing on a hypothesis class of Vapnik-Chervonenkis (VC) dimension $d$, we propose a novel algorithm that yields an $\varepsilon$-optimal randomized hypothesis with a sample complexity on the order of $(d+k)/\varepsilon^2$ (modulo some logarithmic factor), matching the best-known lower bound. Our algorithmic ideas and theory are further extended to accommodate Rademacher classes. The proposed algorithms are oracle-efficient: they access the hypothesis class solely through an empirical risk minimization oracle.
Additionally, we establish the necessity of randomization, revealing a large sample size barrier when only deterministic hypotheses are permitted. These findings resolve three open problems presented in COLT 2023 (i.e., \citet[Problems 1, 3 and 4]{awasthi2023sample}).
Submitted 23 May, 2024; v1 submitted 8 December, 2023;
originally announced December 2023.
-
A Probabilistic Neural Twin for Treatment Planning in Peripheral Pulmonary Artery Stenosis
Authors:
John D. Lee,
Jakob Richter,
Martin R. Pfaller,
Jason M. Szafron,
Karthik Menon,
Andrea Zanoni,
Michael R. Ma,
Jeffrey A. Feinstein,
Jacqueline Kreutzer,
Alison L. Marsden,
Daniele E. Schiavazzi
Abstract:
The substantial computational cost of high-fidelity models in numerical hemodynamics has, so far, relegated their use mainly to offline treatment planning. New breakthroughs in data-driven architectures and optimization techniques for fast surrogate modeling provide an exciting opportunity to overcome these limitations, enabling the use of such technology for time-critical decisions. We discuss an application to the repair of multiple stenoses in peripheral pulmonary artery disease through either transcatheter pulmonary artery rehabilitation or surgery, where it is of interest to achieve desired pressures and flows at specific locations in the pulmonary artery tree, while minimizing the risk for the patient. Since different degrees of success can be achieved in practice during treatment, we formulate the problem in probability, and solve it through a sample-based approach. We propose a new offline-online pipeline for probabilistic real-time treatment planning which combines offline assimilation of boundary conditions, model reduction, and training dataset generation with online estimation of marginal probabilities, possibly conditioned on the degree of augmentation observed in already repaired lesions. Moreover, we propose a new approach for the parametrization of arbitrarily shaped vascular repairs through iterative corrections of a zero-dimensional approximant. We demonstrate this pipeline for a diseased model of the pulmonary artery tree available through the Vascular Model Repository.
Submitted 1 December, 2023;
originally announced December 2023.
-
Learning Hierarchical Polynomials with Three-Layer Neural Networks
Authors:
Zihao Wang,
Eshaan Nichani,
Jason D. Lee
Abstract:
We study the problem of learning hierarchical polynomials over the standard Gaussian distribution with three-layer neural networks. We specifically consider target functions of the form $h = g \circ p$ where $p : \mathbb{R}^d \rightarrow \mathbb{R}$ is a degree $k$ polynomial and $g: \mathbb{R} \rightarrow \mathbb{R}$ is a degree $q$ polynomial. This function class generalizes the single-index model, which corresponds to $k=1$, and is a natural class of functions possessing an underlying hierarchical structure. Our main result shows that for a large subclass of degree $k$ polynomials $p$, a three-layer neural network trained via layerwise gradient descent on the square loss learns the target $h$ up to vanishing test error in $\widetilde{\mathcal{O}}(d^k)$ samples and polynomial time. This is a strict improvement over kernel methods, which require $\widetilde{Θ}(d^{kq})$ samples, as well as existing guarantees for two-layer networks, which require the target function to be low-rank. Our result also generalizes prior works on three-layer neural networks, which were restricted to the case of $p$ being a quadratic. When $p$ is indeed a quadratic, we achieve the information-theoretically optimal sample complexity $\widetilde{\mathcal{O}}(d^2)$, which is an improvement over prior work (Nichani et al., 2023) requiring a sample size of $\widetilde{Θ}(d^4)$. Our proof proceeds by showing that during the initial stage of training the network performs feature learning to recover the feature $p$ with $\widetilde{\mathcal{O}}(d^k)$ samples. This work demonstrates the ability of three-layer neural networks to learn complex features and as a result, learn a broad class of hierarchical functions.
Submitted 22 November, 2023;
originally announced November 2023.
-
Provably Efficient CVaR RL in Low-rank MDPs
Authors:
Yulai Zhao,
Wenhao Zhan,
Xiaoyan Hu,
Ho-fung Leung,
Farzan Farnia,
Wen Sun,
Jason D. Lee
Abstract:
We study risk-sensitive Reinforcement Learning (RL), where we aim to maximize the Conditional Value at Risk (CVaR) with a fixed risk tolerance $τ$. Prior theoretical work studying risk-sensitive RL focuses on the tabular Markov Decision Processes (MDPs) setting. To extend CVaR RL to settings where the state space is large, function approximation must be deployed. We study CVaR RL in low-rank MDPs with nonlinear function approximation. Low-rank MDPs assume the underlying transition kernel admits a low-rank decomposition, but unlike prior linear models, low-rank MDPs do not assume the feature or state-action representation is known. We propose a novel Upper Confidence Bound (UCB) bonus-driven algorithm to carefully balance the interplay between exploration, exploitation, and representation learning in CVaR RL. We prove that our algorithm achieves a sample complexity of $\tilde{O}\left(\frac{H^7 A^2 d^4}{τ^2 ε^2}\right)$ to yield an $ε$-optimal CVaR, where $H$ is the length of each episode, $A$ is the capacity of the action space, and $d$ is the dimension of representations. Computationally, we design a novel discretized Least-Squares Value Iteration (LSVI) algorithm for the CVaR objective as the planning oracle and show that we can find the near-optimal policy in a polynomial running time with a Maximum Likelihood Estimation oracle. To our knowledge, this is the first provably efficient CVaR RL algorithm in low-rank MDPs.
Submitted 20 November, 2023;
originally announced November 2023.
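As a rough illustration of the CVaR objective named in the abstract above, the following sketch estimates CVaR at risk tolerance $τ$ from a finite sample of returns. It assumes the lower-tail convention (the average of the worst $τ$-fraction of returns), which is common when CVaR is maximized; the function name `empirical_cvar` is hypothetical and this is not the paper's algorithm.

```python
import math


def empirical_cvar(returns, tau):
    """Average of the worst tau-fraction of sampled returns.

    Assumption: lower-tail convention, so larger CVaR is better.
    """
    if not 0 < tau <= 1:
        raise ValueError("tau must lie in (0, 1]")
    ordered = sorted(returns)
    # Keep at least one sample so the average is always defined.
    k = max(1, math.ceil(tau * len(ordered)))
    worst = ordered[:k]
    return sum(worst) / len(worst)


sample_returns = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
cvar_half = empirical_cvar(sample_returns, 0.5)  # mean of the 5 lowest returns
cvar_full = empirical_cvar(sample_returns, 1.0)  # reduces to the plain mean
```

With $τ = 1$ the estimate coincides with the ordinary mean, and smaller $τ$ puts more weight on the worst outcomes.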
-
Sample Complexity for Quadratic Bandits: Hessian Dependent Bounds and Optimal Algorithms
Authors:
Qian Yu,
Yining Wang,
Baihe Huang,
Qi Lei,
Jason D. Lee
Abstract:
In stochastic zeroth-order optimization, a problem of practical relevance is understanding how to fully exploit the local geometry of the underlying objective function. We consider a fundamental setting in which the objective function is quadratic, and provide the first tight characterization of the optimal Hessian-dependent sample complexity. Our contribution is twofold. First, from an information-theoretic point of view, we prove tight lower bounds on Hessian-dependent complexities by introducing a concept called energy allocation, which captures the interaction between the searching algorithm and the geometry of objective functions. A matching upper bound is obtained by solving the optimal energy spectrum. Then, algorithmically, we show the existence of a Hessian-independent algorithm that universally achieves the asymptotic optimal sample complexities for all Hessian instances. The optimal sample complexities achieved by our algorithm remain valid for heavy-tailed noise distributions, which are enabled by a truncation method.
Submitted 25 December, 2023; v1 submitted 21 June, 2023;
originally announced June 2023.
-
Provable Reward-Agnostic Preference-Based Reinforcement Learning
Authors:
Wenhao Zhan,
Masatoshi Uehara,
Wen Sun,
Jason D. Lee
Abstract:
Preference-based Reinforcement Learning (PbRL) is a paradigm in which an RL agent learns to optimize a task using pair-wise preference-based feedback over trajectories, rather than explicit reward signals. While PbRL has demonstrated practical success in fine-tuning language models, existing theoretical work focuses on regret minimization and fails to capture most of the practical frameworks. In this study, we fill in such a gap between theoretical PbRL and practical algorithms by proposing a theoretical reward-agnostic PbRL framework where exploratory trajectories that enable accurate learning of hidden reward functions are acquired before collecting any human feedback. Theoretical analysis demonstrates that our algorithm requires less human feedback for learning the optimal policy under preference-based models with linear parameterization and unknown transitions, compared to the existing theoretical literature. Specifically, our framework can incorporate linear and low-rank MDPs with efficient sample complexity. Additionally, we investigate reward-agnostic RL with action-based comparison feedback and introduce an efficient querying algorithm tailored to this scenario.
Submitted 17 April, 2024; v1 submitted 29 May, 2023;
originally announced May 2023.
-
Reward Collapse in Aligning Large Language Models
Authors:
Ziang Song,
Tianle Cai,
Jason D. Lee,
Weijie J. Su
Abstract:
The extraordinary capabilities of large language models (LLMs) such as ChatGPT and GPT-4 are in part unleashed by aligning them with reward models that are trained on human preferences, which are often represented as rankings of responses to prompts. In this paper, we document the phenomenon of \textit{reward collapse}, an empirical observation where the prevailing ranking-based approach results in an \textit{identical} reward distribution \textit{regardless} of the prompts during the terminal phase of training. This outcome is undesirable as open-ended prompts like ``write a short story about your best friend'' should yield a continuous range of rewards for their completions, while specific prompts like ``what is the capital of New Zealand'' should generate either high or low rewards. Our theoretical investigation reveals that reward collapse is primarily due to the insufficiency of the ranking-based objective function to incorporate prompt-related information during optimization. This insight allows us to derive closed-form expressions for the reward distribution associated with a set of utility functions in an asymptotic regime. To overcome reward collapse, we introduce a prompt-aware optimization scheme that provably admits a prompt-dependent reward distribution within the interpolating regime. Our experimental results suggest that our proposed prompt-aware utility functions significantly alleviate reward collapse during the training of reward models.
Submitted 27 May, 2023;
originally announced May 2023.
-
Provable Offline Preference-Based Reinforcement Learning
Authors:
Wenhao Zhan,
Masatoshi Uehara,
Nathan Kallus,
Jason D. Lee,
Wen Sun
Abstract:
In this paper, we investigate the problem of offline Preference-based Reinforcement Learning (PbRL) with human feedback where feedback is available in the form of preference between trajectory pairs rather than explicit rewards. Our proposed algorithm consists of two main steps: (1) estimate the implicit reward using Maximum Likelihood Estimation (MLE) with general function approximation from offline data and (2) solve a distributionally robust planning problem over a confidence set around the MLE. We consider the general reward setting where the reward can be defined over the whole trajectory and provide a novel guarantee that allows us to learn any target policy with a polynomial number of samples, as long as the target policy is covered by the offline data. This guarantee is the first of its kind with general function approximation. To measure the coverage of the target policy, we introduce a new single-policy concentrability coefficient, which can be upper bounded by the per-trajectory concentrability coefficient. We also establish lower bounds that highlight the necessity of such concentrability and the difference from standard RL, where state-action-wise rewards are directly observed. We further extend and analyze our algorithm when the feedback is given over action pairs.
Submitted 29 September, 2023; v1 submitted 24 May, 2023;
originally announced May 2023.
-
Implicit Bias of Gradient Descent for Logistic Regression at the Edge of Stability
Authors:
Jingfeng Wu,
Vladimir Braverman,
Jason D. Lee
Abstract:
Recent research has observed that in machine learning optimization, gradient descent (GD) often operates at the edge of stability (EoS) [Cohen, et al., 2021], where the stepsizes are set to be large, resulting in non-monotonic losses induced by the GD iterates. This paper studies the convergence and implicit bias of constant-stepsize GD for logistic regression on linearly separable data in the EoS regime. Despite the presence of local oscillations, we prove that the logistic loss can be minimized by GD with \emph{any} constant stepsize over a long time scale. Furthermore, we prove that with \emph{any} constant stepsize, the GD iterates tend to infinity when projected to a max-margin direction (the hard-margin SVM direction) and converge to a fixed vector that minimizes a strongly convex potential when projected to the orthogonal complement of the max-margin direction. In contrast, we also show that in the EoS regime, GD iterates may diverge catastrophically under the exponential loss, highlighting the superiority of the logistic loss. These theoretical findings are in line with numerical simulations and complement existing theories on the convergence and implicit bias of GD for logistic regression, which are only applicable when the stepsizes are sufficiently small.
Submitted 15 October, 2023; v1 submitted 19 May, 2023;
originally announced May 2023.
-
Smoothing the Landscape Boosts the Signal for SGD: Optimal Sample Complexity for Learning Single Index Models
Authors:
Alex Damian,
Eshaan Nichani,
Rong Ge,
Jason D. Lee
Abstract:
We focus on the task of learning a single index model $σ(w^\star \cdot x)$ with respect to the isotropic Gaussian distribution in $d$ dimensions. Prior work has shown that the sample complexity of learning $w^\star$ is governed by the information exponent $k^\star$ of the link function $σ$, which is defined as the index of the first nonzero Hermite coefficient of $σ$. Ben Arous et al. (2021) showed that $n \gtrsim d^{k^\star-1}$ samples suffice for learning $w^\star$ and that this is tight for online SGD. However, the CSQ lower bound for gradient based methods only shows that $n \gtrsim d^{k^\star/2}$ samples are necessary. In this work, we close the gap between the upper and lower bounds by showing that online SGD on a smoothed loss learns $w^\star$ with $n \gtrsim d^{k^\star/2}$ samples. We also draw connections to statistical analyses of tensor PCA and to the implicit regularization effects of minibatch SGD on empirical losses.
Submitted 17 May, 2023;
originally announced May 2023.
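The information exponent mentioned in the abstract above can be estimated numerically. The sketch below is only an illustration under stated assumptions: it uses probabilists' Hermite polynomials, Gauss-Hermite quadrature from NumPy, and starts the search at degree 1 (ignoring the constant term). The function name `information_exponent` and its tolerance are hypothetical choices, not taken from the paper.

```python
import numpy as np
from numpy.polynomial import hermite_e as He


def information_exponent(sigma, max_degree=10, quad_points=40, tol=1e-8):
    """Smallest k >= 1 with E[sigma(z) He_k(z)] != 0 for z ~ N(0, 1).

    Returns None if every coefficient up to max_degree is below tol.
    """
    # Nodes and weights for the weight function exp(-z^2 / 2).
    z, w = He.hermegauss(quad_points)
    w = w / np.sqrt(2.0 * np.pi)  # normalize to a Gaussian expectation
    for k in range(1, max_degree + 1):
        # Coefficient vector [0, ..., 0, 1] selects He_k.
        basis = He.hermeval(z, [0.0] * k + [1.0])
        coeff = np.sum(w * sigma(z) * basis)
        if abs(coeff) > tol:
            return k
    return None


k_tanh = information_exponent(np.tanh)                        # odd link, linear part present
k_square = information_exponent(lambda z: z**2)               # z^2 = He_2(z) + 1
k_cubic = information_exponent(lambda z: z**3 - 3.0 * z)      # exactly He_3(z)
```

A link with a nonzero linear component has exponent 1, while the third Hermite polynomial has exponent 3, which is the harder regime the sample-complexity bounds refer to.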
-
Reward-agnostic Fine-tuning: Provable Statistical Benefits of Hybrid Reinforcement Learning
Authors:
Gen Li,
Wenhao Zhan,
Jason D. Lee,
Yuejie Chi,
Yuxin Chen
Abstract:
This paper studies tabular reinforcement learning (RL) in the hybrid setting, which assumes access to both an offline dataset and online interactions with the unknown environment. A central question boils down to how to efficiently utilize online data collection to strengthen and complement the offline dataset and enable effective policy fine-tuning. Leveraging recent advances in reward-agnostic exploration and model-based offline RL, we design a three-stage hybrid RL algorithm that beats the best of both worlds -- pure offline RL and pure online RL -- in terms of sample complexities. The proposed algorithm does not require any reward information during data collection. Our theory is developed based on a new notion called single-policy partial concentrability, which captures the trade-off between distribution mismatch and miscoverage and guides the interplay between offline and online data.
Submitted 17 May, 2023;
originally announced May 2023.
-
Provable Guarantees for Nonlinear Feature Learning in Three-Layer Neural Networks
Authors:
Eshaan Nichani,
Alex Damian,
Jason D. Lee
Abstract:
One of the central questions in the theory of deep learning is to understand how neural networks learn hierarchical features. The ability of deep networks to extract salient features is crucial to both their outstanding generalization ability and the modern deep learning paradigm of pretraining and finetuning. However, this feature learning process remains poorly understood from a theoretical perspective, with existing analyses largely restricted to two-layer networks. In this work we show that three-layer neural networks have provably richer feature learning capabilities than two-layer networks. We analyze the features learned by a three-layer network trained with layer-wise gradient descent, and present a general purpose theorem which upper bounds the sample complexity and width needed to achieve low test error when the target has specific hierarchical structure. We instantiate our framework in specific statistical learning settings -- single-index models and functions of quadratic features -- and show that in the latter setting three-layer networks obtain a sample complexity improvement over all existing guarantees for two-layer networks. Crucially, this sample complexity improvement relies on the ability of three-layer networks to efficiently learn nonlinear features. We then establish a concrete optimization-based depth separation by constructing a function which is efficiently learnable via gradient descent on a three-layer network, yet cannot be learned efficiently by a two-layer network. Our work makes progress towards understanding the provable benefit of three-layer neural networks over two-layer networks in the feature learning regime.
Submitted 1 April, 2025; v1 submitted 11 May, 2023;
originally announced May 2023.
-
Local Optimization Achieves Global Optimality in Multi-Agent Reinforcement Learning
Authors:
Yulai Zhao,
Zhuoran Yang,
Zhaoran Wang,
Jason D. Lee
Abstract:
Policy optimization methods with function approximation are widely used in multi-agent reinforcement learning. However, it remains elusive how to design such algorithms with statistical guarantees. Leveraging a multi-agent performance difference lemma that characterizes the landscape of multi-agent policy optimization, we find that the localized action value function serves as an ideal descent direction for each local policy. Motivated by the observation, we present a multi-agent PPO algorithm in which the local policy of each agent is updated similarly to vanilla PPO. We prove that with standard regularity conditions on the Markov game and problem-dependent quantities, our algorithm converges to the globally optimal policy at a sublinear rate. We extend our algorithm to the off-policy setting and introduce pessimism to policy evaluation, which aligns with experiments. To our knowledge, this is the first provably convergent multi-agent PPO algorithm in cooperative Markov games.
Submitted 8 May, 2023;
originally announced May 2023.
-
Efficient displacement convex optimization with particle gradient descent
Authors:
Hadi Daneshmand,
Jason D. Lee,
Chi Jin
Abstract:
Particle gradient descent, which uses particles to represent a probability measure and performs gradient descent on particles in parallel, is widely used to optimize functions of probability measures. This paper considers particle gradient descent with a finite number of particles and establishes its theoretical guarantees to optimize functions that are \emph{displacement convex} in measures. Concretely, for Lipschitz displacement convex functions defined on probability over $\mathbb{R}^d$, we prove that $O(1/ε^2)$ particles and $O(d/ε^4)$ computations are sufficient to find the $ε$-optimal solutions. We further provide improved complexity bounds for optimizing smooth displacement convex functions. We demonstrate the application of our results for function approximation with specific neural architectures with two-dimensional inputs.
Submitted 9 February, 2023;
originally announced February 2023.
-
Offline Minimax Soft-Q-learning Under Realizability and Partial Coverage
Authors:
Masatoshi Uehara,
Nathan Kallus,
Jason D. Lee,
Wen Sun
Abstract:
In offline reinforcement learning (RL) we have no opportunity to explore so we must make assumptions that the data is sufficient to guide picking a good policy, taking the form of assuming some coverage, realizability, Bellman completeness, and/or hard margin (gap). In this work we propose value-based algorithms for offline RL with PAC guarantees under just partial coverage, specifically, coverage of just a single comparator policy, and realizability of the soft (entropy-regularized) Q-function of the single policy and a related function defined as a saddle point of a certain minimax optimization problem. This offers refined and generally more lax conditions for offline RL. We further show an analogous result for vanilla Q-functions under a soft margin condition. To attain these guarantees, we leverage novel minimax learning algorithms to accurately estimate soft or vanilla Q-functions with $L^2$-convergence guarantees. Our algorithms' loss functions arise from casting the estimation problems as nonlinear convex optimization problems and Lagrangifying.
Submitted 13 November, 2023; v1 submitted 5 February, 2023;
originally announced February 2023.
-
Understanding Incremental Learning of Gradient Descent: A Fine-grained Analysis of Matrix Sensing
Authors:
Jikai Jin,
Zhiyuan Li,
Kaifeng Lyu,
Simon S. Du,
Jason D. Lee
Abstract:
It is believed that Gradient Descent (GD) induces an implicit bias towards good generalization in training machine learning models. This paper provides a fine-grained analysis of the dynamics of GD for the matrix sensing problem, whose goal is to recover a low-rank ground-truth matrix from near-isotropic linear measurements. It is shown that GD with small initialization behaves similarly to the greedy low-rank learning heuristics (Li et al., 2020) and follows an incremental learning procedure (Gissin et al., 2019): GD sequentially learns solutions with increasing ranks until it recovers the ground truth matrix. Compared to existing works which only analyze the first learning phase for rank-1 solutions, our result provides characterizations for the whole learning process. Moreover, besides the over-parameterized regime that many prior works focused on, our analysis of the incremental learning procedure also applies to the under-parameterized regime. Finally, we conduct numerical experiments to confirm our theoretical findings.
Submitted 26 January, 2023;
originally announced January 2023.
-
Reconstructing Training Data from Model Gradient, Provably
Authors:
Zihan Wang,
Jason D. Lee,
Qi Lei
Abstract:
Understanding when and how much a model gradient leaks information about the training sample is an important question in privacy. In this paper, we present a surprising result: even without training or memorizing the data, we can fully reconstruct the training samples from a single gradient query at a randomly chosen parameter value. We prove the identifiability of the training data under mild conditions: with shallow or deep neural networks and a wide range of activation functions. We also present a statistically and computationally efficient algorithm based on tensor decomposition to reconstruct the training data. As a provable attack that reveals sensitive training data, our findings suggest potential severe threats to privacy, especially in federated learning.
Submitted 10 June, 2023; v1 submitted 7 December, 2022;
originally announced December 2022.
-
Self-Stabilization: The Implicit Bias of Gradient Descent at the Edge of Stability
Authors:
Alex Damian,
Eshaan Nichani,
Jason D. Lee
Abstract:
Traditional analyses of gradient descent show that when the largest eigenvalue of the Hessian, also known as the sharpness $S(θ)$, is bounded by $2/η$, training is "stable" and the training loss decreases monotonically. Recent works, however, have observed that this assumption does not hold when training modern neural networks with full batch or large batch gradient descent. Most recently, Cohen et al. (2021) observed two important phenomena. The first, dubbed progressive sharpening, is that the sharpness steadily increases throughout training until it reaches the instability cutoff $2/η$. The second, dubbed edge of stability, is that the sharpness hovers at $2/η$ for the remainder of training while the loss continues decreasing, albeit non-monotonically. We demonstrate that, far from being chaotic, the dynamics of gradient descent at the edge of stability can be captured by a cubic Taylor expansion: as the iterates diverge in direction of the top eigenvector of the Hessian due to instability, the cubic term in the local Taylor expansion of the loss function causes the curvature to decrease until stability is restored. This property, which we call self-stabilization, is a general property of gradient descent and explains its behavior at the edge of stability. A key consequence of self-stabilization is that gradient descent at the edge of stability implicitly follows projected gradient descent (PGD) under the constraint $S(θ) \le 2/η$. Our analysis provides precise predictions for the loss, sharpness, and deviation from the PGD trajectory throughout training, which we verify both empirically in a number of standard settings and theoretically under mild conditions. Our analysis uncovers the mechanism for gradient descent's implicit bias towards stability.
Submitted 10 April, 2023; v1 submitted 30 September, 2022;
originally announced September 2022.
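The $2/η$ stability threshold described in the abstract above can be seen on a one-dimensional quadratic, where the sharpness is simply the curvature. The sketch below is a toy illustration, not the paper's analysis: for $f(x) = \tfrac{1}{2} S x^2$ the gradient descent update is $x \leftarrow (1 - ηS)x$, which contracts exactly when $S < 2/η$. The function name and the chosen constants are hypothetical.

```python
def gd_on_quadratic(sharpness, step_size, x0=1.0, num_steps=200):
    """Run gradient descent on f(x) = 0.5 * sharpness * x**2; return final |x|."""
    x = x0
    for _ in range(num_steps):
        grad = sharpness * x          # f'(x) for the quadratic
        x = x - step_size * grad      # multiplies x by (1 - step_size * sharpness)
    return abs(x)


step_size = 0.1
threshold = 2.0 / step_size           # instability cutoff 2 / eta

# Sharpness slightly below the cutoff: iterates oscillate in sign but shrink.
below = gd_on_quadratic(0.9 * threshold, step_size)
# Sharpness slightly above the cutoff: iterates oscillate and blow up.
above = gd_on_quadratic(1.1 * threshold, step_size)
```

A pure quadratic has no cubic term, so it diverges above the cutoff; the self-stabilization effect in the abstract relies on higher-order terms that this toy model deliberately omits.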
-
Neural Networks can Learn Representations with Gradient Descent
Authors:
Alex Damian,
Jason D. Lee,
Mahdi Soltanolkotabi
Abstract:
Significant theoretical work has established that in specific regimes, neural networks trained by gradient descent behave like kernel methods. However, in practice, it is known that neural networks strongly outperform their associated kernels. In this work, we explain this gap by demonstrating that there is a large class of functions which cannot be efficiently learned by kernel methods but can be easily learned with gradient descent on a two layer neural network outside the kernel regime by learning representations that are relevant to the target task. We also demonstrate that these representations allow for efficient transfer learning, which is impossible in the kernel regime.
Specifically, we consider the problem of learning polynomials which depend on only a few relevant directions, i.e. of the form $f^\star(x) = g(Ux)$ where $U: \mathbb{R}^d \to \mathbb{R}^r$ with $d \gg r$. When the degree of $f^\star$ is $p$, it is known that $n \asymp d^p$ samples are necessary to learn $f^\star$ in the kernel regime. Our primary result is that gradient descent learns a representation of the data which depends only on the directions relevant to $f^\star$. This results in an improved sample complexity of $n\asymp d^2 r + dr^p$. Furthermore, in a transfer learning setup where the data distributions in the source and target domain share the same representation $U$ but have different polynomial heads we show that a popular heuristic for transfer learning has a target sample complexity independent of $d$.
Submitted 30 June, 2022;
originally announced June 2022.
-
Computationally Efficient PAC RL in POMDPs with Latent Determinism and Conditional Embeddings
Authors:
Masatoshi Uehara,
Ayush Sekhari,
Jason D. Lee,
Nathan Kallus,
Wen Sun
Abstract:
We study reinforcement learning with function approximation for large-scale Partially Observable Markov Decision Processes (POMDPs) where the state space and observation space are large or even continuous. In particular, we consider Hilbert space embeddings of POMDPs where the feature of latent states and the feature of observations admit a conditional Hilbert space embedding of the observation emission process, and the latent state transition is deterministic. Under the function approximation setup where the optimal latent state-action $Q$-function is linear in the state feature, and the optimal $Q$-function has a gap in actions, we provide a \emph{computationally and statistically efficient} algorithm for finding the \emph{exact optimal} policy. We show our algorithm's computational and statistical complexities scale polynomially with respect to the horizon and the intrinsic dimension of the feature on the observation space. Furthermore, we show both the deterministic latent transitions and gap assumptions are necessary to avoid statistical complexity exponential in horizon or dimension. Since our guarantee does not have an explicit dependence on the size of the state and observation spaces, our algorithm provably scales to large-scale POMDPs.
Submitted 24 June, 2022;
originally announced June 2022.
-
Provably Efficient Reinforcement Learning in Partially Observable Dynamical Systems
Authors:
Masatoshi Uehara,
Ayush Sekhari,
Jason D. Lee,
Nathan Kallus,
Wen Sun
Abstract:
We study Reinforcement Learning for partially observable dynamical systems using function approximation. We propose a new \textit{Partially Observable Bilinear Actor-Critic framework} that is general enough to include models such as observable tabular Partially Observable Markov Decision Processes (POMDPs), observable Linear-Quadratic-Gaussian (LQG), Predictive State Representations (PSRs), as well as a newly introduced model, Hilbert Space Embeddings of POMDPs, and observable POMDPs with latent low-rank transition. Under this framework, we propose an actor-critic style algorithm that is capable of performing agnostic policy learning. Given a policy class that consists of memory-based policies (that look at a fixed-length window of recent observations), and a value function class that consists of functions taking both memory and future observations as inputs, our algorithm learns to compete against the best memory-based policy in the given policy class. For certain examples such as undercomplete observable tabular POMDPs, observable LQGs and observable POMDPs with latent low-rank transition, by implicitly leveraging their special properties, our algorithm is even capable of competing against the globally optimal policy without paying an exponential dependence on the horizon in its sample complexity.
Submitted 23 June, 2022;
originally announced June 2022.
-
Identifying good directions to escape the NTK regime and efficiently learn low-degree plus sparse polynomials
Authors:
Eshaan Nichani,
Yu Bai,
Jason D. Lee
Abstract:
A recent goal in the theory of deep learning is to identify how neural networks can escape the "lazy training," or Neural Tangent Kernel (NTK) regime, where the network is coupled with its first-order Taylor expansion at initialization. While the NTK is minimax optimal for learning dense polynomials (Ghorbani et al., 2021), it cannot learn features, and hence has poor sample complexity for learning many classes of functions including sparse polynomials. Recent works have thus aimed to identify settings where gradient-based algorithms provably generalize better than the NTK. One such example is the "QuadNTK" approach of Bai and Lee (2020), which analyzes the second-order term in the Taylor expansion. Bai and Lee (2020) show that the second-order term can learn sparse polynomials efficiently; however, it sacrifices the ability to learn general dense polynomials.
In this paper, we analyze how gradient descent on a two-layer neural network can escape the NTK regime by utilizing a spectral characterization of the NTK (Montanari and Zhong, 2020) and building on the QuadNTK approach. We first expand upon the spectral analysis to identify "good" directions in parameter space in which we can move without harming generalization. Next, we show that a wide two-layer neural network can jointly use the NTK and QuadNTK to fit target functions consisting of a dense low-degree term and a sparse high-degree term -- something neither the NTK nor the QuadNTK can do on its own. Finally, we construct a regularizer which encourages our parameter vector to move in the "good" directions, and show that gradient descent on the regularized loss will converge to a global minimizer, which also has low test error. This yields an end-to-end convergence and generalization guarantee with provable sample complexity improvement over both the NTK and QuadNTK on their own.
Submitted 26 November, 2022; v1 submitted 8 June, 2022;
originally announced June 2022.
-
Decentralized Optimistic Hyperpolicy Mirror Descent: Provably No-Regret Learning in Markov Games
Authors:
Wenhao Zhan,
Jason D. Lee,
Zhuoran Yang
Abstract:
We study decentralized policy learning in Markov games where we control a single agent to play with nonstationary and possibly adversarial opponents. Our goal is to develop a no-regret online learning algorithm that (i) takes actions based on the local information observed by the agent and (ii) is able to find the best policy in hindsight. For such a problem, the nonstationary state transitions due to the varying opponent pose a significant challenge. In light of a recent hardness result \citep{liu2022learning}, we focus on the setting where the opponent's previous policies are revealed to the agent for decision making. With such an information structure, we propose a new algorithm, \underline{D}ecentralized \underline{O}ptimistic hype\underline{R}policy m\underline{I}rror de\underline{S}cent (DORIS), which achieves $\sqrt{K}$-regret in the context of general function approximation, where $K$ is the number of episodes. Moreover, when all the agents adopt DORIS, we prove that their mixture policy constitutes an approximate coarse correlated equilibrium. In particular, DORIS maintains a \textit{hyperpolicy} which is a distribution over the policy space. The hyperpolicy is updated via mirror descent, where the update direction is obtained by an optimistic variant of least-squares policy evaluation. Furthermore, to illustrate the power of our method, we apply DORIS to constrained and vector-valued MDPs, which can be formulated as zero-sum Markov games with a fictitious opponent.
Submitted 3 June, 2022;
originally announced June 2022.
-
On the Effective Number of Linear Regions in Shallow Univariate ReLU Networks: Convergence Guarantees and Implicit Bias
Authors:
Itay Safran,
Gal Vardi,
Jason D. Lee
Abstract:
We study the dynamics and implicit bias of gradient flow (GF) on univariate ReLU neural networks with a single hidden layer in a binary classification setting. We show that when the labels are determined by the sign of a target network with $r$ neurons, with high probability over the initialization of the network and the sampling of the dataset, GF converges in direction (suitably defined) to a network achieving perfect training accuracy and having at most $\mathcal{O}(r)$ linear regions, implying a generalization bound. Unlike many other results in the literature, under an additional assumption on the distribution of the data, our result holds even for mild over-parameterization, where the width is $\tilde{\mathcal{O}}(r)$ and independent of the sample size.
Submitted 2 February, 2023; v1 submitted 18 May, 2022;
originally announced May 2022.
-
Nearly Minimax Algorithms for Linear Bandits with Shared Representation
Authors:
Jiaqi Yang,
Qi Lei,
Jason D. Lee,
Simon S. Du
Abstract:
We give novel algorithms for multi-task and lifelong linear bandits with shared representation. Specifically, we consider the setting where we play $M$ linear bandits with dimension $d$, each for $T$ rounds, and these $M$ bandit tasks share a common $k(\ll d)$ dimensional linear representation. For both the multi-task setting where we play the tasks concurrently, and the lifelong setting where we play tasks sequentially, we come up with novel algorithms that achieve $\widetilde{O}\left(d\sqrt{kMT} + kM\sqrt{T}\right)$ regret bounds, which match the known minimax regret lower bound up to logarithmic factors and close the gap in existing results [Yang et al., 2021]. Our main techniques include a more efficient estimator for the low-rank linear feature extractor and an accompanying novel analysis for this estimator.
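As a rough illustration of what the bound above buys, the sketch below evaluates $d\sqrt{kMT} + kM\sqrt{T}$ next to $Md\sqrt{T}$, the order obtained by running a standard $\widetilde{O}(d\sqrt{T})$ linear bandit algorithm on each task separately. That independent baseline, the function names, and the parameter values are assumptions made here for comparison; constants and logarithmic factors are dropped.

```python
import math


def shared_representation_regret(d: int, k: int, M: int, T: int) -> float:
    """Order of the regret bound stated in the abstract (no constants or logs)."""
    return d * math.sqrt(k * M * T) + k * M * math.sqrt(T)


def independent_tasks_regret(d: int, M: int, T: int) -> float:
    """Assumed baseline: each of the M tasks pays d * sqrt(T) on its own."""
    return M * d * math.sqrt(T)


# Hypothetical instance: 50 tasks in dimension 100 sharing a 5-dimensional
# representation, each played for 10,000 rounds.
d, k, M, T = 100, 5, 50, 10_000
shared = shared_representation_regret(d, k, M, T)
independent = independent_tasks_regret(d, M, T)
print(f"shared: {shared:.0f}, independent: {independent:.0f}")
```

The first term grows like $\sqrt{M}$ rather than $M$, so the advantage over the independent baseline widens as the number of tasks increases, provided $k \ll d$.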
Submitted 29 March, 2022;
originally announced March 2022.
-
Offline Reinforcement Learning with Realizability and Single-policy Concentrability
Authors:
Wenhao Zhan,
Baihe Huang,
Audrey Huang,
Nan Jiang,
Jason D. Lee
Abstract:
Sample-efficiency guarantees for offline reinforcement learning (RL) often rely on strong assumptions on both the function classes (e.g., Bellman-completeness) and the data coverage (e.g., all-policy concentrability). Despite the recent efforts on relaxing these assumptions, existing works are only able to relax one of the two factors, leaving the strong assumption on the other factor intact. As an important open problem, can we achieve sample-efficient offline RL with weak assumptions on both factors?
In this paper, we answer this question in the affirmative. We analyze a simple algorithm based on the primal-dual formulation of MDPs, where the dual variables (discounted occupancy) are modeled using a density-ratio function against offline data. With proper regularization, we show that the algorithm enjoys polynomial sample complexity, under only realizability and single-policy concentrability. We also provide alternative analyses based on different assumptions to shed light on the nature of primal-dual algorithms for offline RL.
Submitted 27 June, 2022; v1 submitted 9 February, 2022;
originally announced February 2022.
-
Provable Hierarchy-Based Meta-Reinforcement Learning
Authors:
Kurtland Chua,
Qi Lei,
Jason D. Lee
Abstract:
Hierarchical reinforcement learning (HRL) has seen widespread interest as an approach to tractable learning of complex modular behaviors. However, existing work either assumes access to expert-constructed hierarchies, or uses hierarchy-learning heuristics with no provable guarantees. To address this gap, we analyze HRL in the meta-RL setting, where a learner learns latent hierarchical structure during meta-training for use in a downstream task. We consider a tabular setting where natural hierarchical structure is embedded in the transition dynamics. Analogous to supervised meta-learning theory, we provide "diversity conditions" which, together with a tractable optimism-based algorithm, guarantee sample-efficient recovery of this natural hierarchy. Furthermore, we provide regret bounds on a learner using the recovered hierarchy to solve a meta-test task. Our bounds incorporate common notions in HRL literature such as temporal and state/action abstractions, suggesting that our setting and analysis capture important features of HRL in practice.
Submitted 18 October, 2021;
originally announced October 2021.
-
Towards General Function Approximation in Zero-Sum Markov Games
Authors:
Baihe Huang,
Jason D. Lee,
Zhaoran Wang,
Zhuoran Yang
Abstract:
This paper considers two-player zero-sum finite-horizon Markov games with simultaneous moves. The study focuses on the challenging settings where the value function or the model is parameterized by general function classes. Provably efficient algorithms for both decoupled and coordinated settings are developed. In the decoupled setting where the agent controls a single player and plays against an arbitrary opponent, we propose a new model-free algorithm. The sample complexity is governed by the Minimax Eluder dimension -- a new dimension of the function class in Markov games. As a special case, this method improves the state-of-the-art algorithm by a $\sqrt{d}$ factor in the regret when the reward function and transition kernel are parameterized with $d$-dimensional linear features. In the coordinated setting where both players are controlled by the agent, we propose a model-based algorithm and a model-free algorithm. In the model-based algorithm, we prove that sample complexity can be bounded by a generalization of Witness rank to Markov games. The model-free algorithm enjoys a $\sqrt{K}$-regret upper bound where $K$ is the number of episodes.
Submitted 30 October, 2021; v1 submitted 30 July, 2021;
originally announced July 2021.
-
Going Beyond Linear RL: Sample Efficient Neural Function Approximation
Authors:
Baihe Huang,
Kaixuan Huang,
Sham M. Kakade,
Jason D. Lee,
Qi Lei,
Runzhe Wang,
Jiaqi Yang
Abstract:
Deep Reinforcement Learning (RL) powered by neural net approximation of the Q function has had enormous empirical success. While the theory of RL has traditionally focused on linear function approximation (or eluder dimension) approaches, little is known about nonlinear RL with neural net approximations of the Q functions. This is the focus of this work, where we study function approximation with two-layer neural networks (considering both ReLU and polynomial activation functions). Our first result is a computationally and statistically efficient algorithm in the generative model setting under completeness for two-layer neural networks. Our second result considers this setting but under only realizability of the neural net function class. Here, assuming deterministic dynamics, the sample complexity scales linearly in the algebraic dimension. In all cases, our results significantly improve upon what can be attained with linear (or eluder dimension) methods.
Submitted 25 December, 2021; v1 submitted 13 July, 2021;
originally announced July 2021.
-
Optimal Gradient-based Algorithms for Non-concave Bandit Optimization
Authors:
Baihe Huang,
Kaixuan Huang,
Sham M. Kakade,
Jason D. Lee,
Qi Lei,
Runzhe Wang,
Jiaqi Yang
Abstract:
Bandit problems with linear or concave reward have been extensively studied, but relatively few works have studied bandits with non-concave reward. This work considers a large family of bandit problems where the unknown underlying reward function is non-concave, including the low-rank generalized linear bandit problems and two-layer neural network with polynomial activation bandit problem. For the low-rank generalized linear bandit problem, we provide a minimax-optimal algorithm in the dimension, refuting both conjectures in [LMT21, JWWN19]. Our algorithms are based on a unified zeroth-order optimization paradigm that applies in great generality and attains optimal rates in several structured polynomial settings (in the dimension). We further demonstrate the applicability of our algorithms in RL in the generative model setting, resulting in improved sample complexity over prior approaches. Finally, we show that the standard optimistic algorithms (e.g., UCB) are sub-optimal by dimension factors. In the neural net setting (with polynomial activation functions) with noiseless reward, we provide a bandit algorithm with sample complexity equal to the intrinsic algebraic dimension. Again, we show that optimistic approaches have worse sample complexity, polynomial in the extrinsic dimension (which could be exponentially worse in the polynomial degree).
Submitted 9 July, 2021;
originally announced July 2021.
-
A Short Note on the Relationship of Information Gain and Eluder Dimension
Authors:
Kaixuan Huang,
Sham M. Kakade,
Jason D. Lee,
Qi Lei
Abstract:
Eluder dimension and information gain are two widely used complexity measures in bandit and reinforcement learning. Eluder dimension was originally proposed as a general complexity measure of function classes, but the common examples where it is known to be small are function spaces (vector spaces). In these cases, the primary tool to upper bound the eluder dimension is the elliptic potential lemma. Interestingly, the elliptic potential lemma also features prominently in the analysis of linear bandits/reinforcement learning and their nonparametric generalization, the information gain. We show that this is not a coincidence -- eluder dimension and information gain are equivalent in a precise sense for reproducing kernel Hilbert spaces.
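For readers unfamiliar with the elliptic potential lemma mentioned above, a standard form states that with $V_0 = \lambda I$ and $V_t = V_{t-1} + x_t x_t^\top$, one has $\sum_{t=1}^T \min(1, \|x_t\|^2_{V_{t-1}^{-1}}) \le 2\log(\det V_T / \det V_0)$. The sketch below checks this numerically in two dimensions using plain Python; the choice of unit-norm inputs on a circle, $\lambda = 1$, and the function name are illustrative assumptions.

```python
import math


def elliptic_potential(xs, lam=1.0):
    """Return (sum_t min(1, x_t^T V_{t-1}^{-1} x_t), log det(V_T) - log det(V_0)).

    V is a symmetric 2x2 matrix stored as [[a, b], [b, c]], starting at lam * I.
    """
    a, b, c = lam, 0.0, lam
    potential = 0.0
    for x1, x2 in xs:
        det = a * c - b * b
        # Quadratic form with the closed-form 2x2 inverse [[c, -b], [-b, a]] / det.
        quad = (c * x1 * x1 - 2.0 * b * x1 * x2 + a * x2 * x2) / det
        potential += min(1.0, quad)
        # Rank-one update V <- V + x x^T.
        a, b, c = a + x1 * x1, b + x1 * x2, c + x2 * x2
    log_det_ratio = math.log((a * c - b * b) / (lam * lam))
    return potential, log_det_ratio


# 200 unit vectors at angles 1, 2, ..., 200 radians (an arbitrary test sequence).
xs = [(math.cos(t), math.sin(t)) for t in range(1, 201)]
potential, log_det_ratio = elliptic_potential(xs)
print(potential, 2.0 * log_det_ratio)
```

The quantity $\log(\det V_T / \det V_0)$ on the right-hand side is, up to a factor, the information gain for a linear kernel, which is one way to see why the same lemma appears in both analyses.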
Submitted 6 July, 2021;
originally announced July 2021.
-
Near-Optimal Linear Regression under Distribution Shift
Authors:
Qi Lei,
Wei Hu,
Jason D. Lee
Abstract:
Transfer learning is essential when sufficient data is available from the source domain but labeled data from the target domain is scarce. We develop estimators that achieve minimax linear risk for linear regression problems under distribution shift. Our algorithms cover different transfer learning settings including covariate shift and model shift. We also consider settings where data are generated from either linear or general nonlinear models. We show that linear minimax estimators are within an absolute constant of the minimax risk even among nonlinear estimators for various source/target distributions.
Submitted 22 June, 2021;
originally announced June 2021.
-
Label Noise SGD Provably Prefers Flat Global Minimizers
Authors:
Alex Damian,
Tengyu Ma,
Jason D. Lee
Abstract:
In overparametrized models, the noise in stochastic gradient descent (SGD) implicitly regularizes the optimization trajectory and determines which local minimum SGD converges to. Motivated by empirical studies that demonstrate that training with noisy labels improves generalization, we study the implicit regularization effect of SGD with label noise. We show that SGD with label noise converges to a stationary point of a regularized loss $L(θ) +λR(θ)$, where $L(θ)$ is the training loss, $λ$ is an effective regularization parameter depending on the step size, strength of the label noise, and the batch size, and $R(θ)$ is an explicit regularizer that penalizes sharp minimizers. Our analysis uncovers an additional regularization effect of large learning rates beyond the linear scaling rule that penalizes large eigenvalues of the Hessian more than small ones. We also prove extensions to classification with general loss functions, SGD with momentum, and SGD with general noise covariance, significantly strengthening the prior work of Blanc et al. to global convergence and large learning rates and of HaoChen et al. to general models.
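The preference for flat minimizers can be seen in a toy simulation. The sketch below is an illustration constructed here, not the paper's setting or its regularizer $R(θ)$: it fits the product model $θ_1 θ_2$ to the label $1$, so every point with $θ_1 θ_2 = 1$ is a global minimizer, and the Hessian trace of the noiseless loss at such a point is $θ_1^2 + θ_2^2$, which is smallest at the balanced solution $θ_1 = θ_2$. Fresh Gaussian label noise is drawn at every step; the step size, noise level, seed, and starting point are arbitrary assumptions.

```python
import random


def label_noise_sgd(theta1, theta2, eta=0.1, sigma=0.5, steps=5000, seed=0):
    """SGD on 0.5 * (theta1 * theta2 - (1 + noise))^2 with fresh label noise."""
    rng = random.Random(seed)
    for _ in range(steps):
        noisy_label = 1.0 + rng.gauss(0.0, sigma)
        residual = theta1 * theta2 - noisy_label
        # Simultaneous gradient step on both parameters.
        theta1, theta2 = (theta1 - eta * residual * theta2,
                          theta2 - eta * residual * theta1)
    return theta1, theta2


# Start at an unbalanced (sharp) global minimizer: 2 * 0.5 = 1.
start = (2.0, 0.5)
end = label_noise_sgd(*start)

sharpness_start = start[0] ** 2 + start[1] ** 2  # Hessian trace proxy: 4.25
sharpness_end = end[0] ** 2 + end[1] ** 2
imbalance_end = abs(end[0] ** 2 - end[1] ** 2)
print(end, sharpness_start, sharpness_end, imbalance_end)
```

In this toy model each update multiplies $θ_1^2 - θ_2^2$ by $1 - η^2 g^2$, where $g$ is the noisy residual, so the imbalance shrinks only because the label noise keeps $g$ away from zero; without noise the iterate would stay at the sharp starting minimizer.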
Submitted 4 December, 2021; v1 submitted 11 June, 2021;
originally announced June 2021.
-
Policy Mirror Descent for Regularized Reinforcement Learning: A Generalized Framework with Linear Convergence
Authors:
Wenhao Zhan,
Shicong Cen,
Baihe Huang,
Yuxin Chen,
Jason D. Lee,
Yuejie Chi
Abstract:
Policy optimization, which finds the desired policy by maximizing value functions via optimization techniques, lies at the heart of reinforcement learning (RL). In addition to value maximization, other practical considerations arise as well, including the need to encourage exploration and to ensure certain structural properties of the learned policy due to safety, resource and operational constraints. These can often be accounted for via regularized RL, which augments the target value function with a structure-promoting regularizer.
Focusing on discounted infinite-horizon Markov decision processes, we propose a generalized policy mirror descent (GPMD) algorithm for solving regularized RL. As a generalization of policy mirror descent (arXiv:2102.00135), our algorithm accommodates a general class of convex regularizers and promotes the use of a Bregman divergence cognizant of the regularizer in use. We demonstrate that our algorithm converges linearly to the global solution over an entire range of learning rates, in a dimension-free fashion, even when the regularizer lacks strong convexity and smoothness. In addition, this linear convergence feature is provably stable in the face of inexact policy evaluation and imperfect policy updates. Numerical experiments are provided to corroborate the appealing performance of GPMD.
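A minimal sketch of the mirror descent idea, in a much simpler setting than the paper's: a single-state problem (a bandit with known rewards) with a negative-entropy regularizer of weight $τ$ and a KL Bregman term. In that special case the proximal step $\min_p \{-η\langle r, p\rangle + ητ \sum_a p_a \log p_a + \mathrm{KL}(p \,\|\, π_t)\}$ has the closed form $π_{t+1}(a) \propto π_t(a)^{1/(1+ητ)} \exp(η r(a)/(1+ητ))$, and its fixed point is the softmax of $r/τ$. This is not GPMD itself; the rewards, $η$, $τ$, and function names are assumptions chosen for illustration.

```python
import math


def softmax(logits):
    """Numerically stable softmax over a list of floats."""
    m = max(logits)
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]


def pmd_step(policy, rewards, eta, tau):
    """One entropy-regularized mirror descent step with a KL proximal term."""
    scale = 1.0 / (1.0 + eta * tau)
    logits = [scale * (math.log(p) + eta * r) for p, r in zip(policy, rewards)]
    return softmax(logits)


rewards = [1.0, 0.5, 0.0]  # hypothetical per-action rewards
eta, tau = 1.0, 0.5
policy = [1.0 / 3.0] * 3   # uniform initialization
target = softmax([r / tau for r in rewards])  # regularized optimum

errors = []
for _ in range(30):
    policy = pmd_step(policy, rewards, eta, tau)
    errors.append(max(abs(p - q) for p, q in zip(policy, target)))
print(errors[0], errors[-1])
```

In this toy case the log-policy error contracts by the factor $1/(1+ητ)$ at every step for any $η > 0$, which is a one-state analogue of the linear convergence over a range of learning rates described above.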
Submitted 10 January, 2023; v1 submitted 23 May, 2021;
originally announced May 2021.
-
How Fine-Tuning Allows for Effective Meta-Learning
Authors:
Kurtland Chua,
Qi Lei,
Jason D. Lee
Abstract:
Representation learning has been widely studied in the context of meta-learning, enabling rapid learning of new tasks through shared representations. Recent works such as MAML have explored using fine-tuning-based metrics, which measure the ease by which fine-tuning can achieve good performance, as proxies for obtaining representations. We present a theoretical framework for analyzing representations derived from a MAML-like algorithm, assuming the available tasks use approximately the same underlying representation. We then provide risk bounds on the best predictor found by fine-tuning via gradient descent, demonstrating that the algorithm can provably leverage the shared structure. The upper bound applies to general function classes, which we demonstrate by instantiating the guarantees of our framework in the logistic regression and neural network settings. In contrast, we establish the existence of settings where any algorithm, using a representation trained with no consideration for task-specific fine-tuning, performs as well as a learner with no access to source tasks in the worst case. This separation result underscores the benefit of fine-tuning-based methods, such as MAML, over methods with "frozen representation" objectives in few-shot learning.
Submitted 5 May, 2021;
originally announced May 2021.