-
TADA: Improved Diffusion Sampling with Training-free Augmented Dynamics
Authors:
Tianrong Chen,
Huangjie Zheng,
David Berthelot,
Jiatao Gu,
Josh Susskind,
Shuangfei Zhai
Abstract:
Diffusion models have demonstrated exceptional capabilities in generating high-fidelity images but typically suffer from inefficient sampling. Many solver designs and noise scheduling strategies have been proposed to dramatically improve sampling speeds. In this paper, we introduce a new sampling method that is up to $186\%$ faster than the current state-of-the-art solver at comparable FID on ImageNet512. This new sampling method is training-free and uses an ordinary differential equation (ODE) solver. The key to our method is the use of higher-dimensional initial noise, which allows existing pretrained diffusion models to produce more detailed samples with fewer function evaluations. In addition, by design our solver allows the level of detail to be controlled through a simple hyper-parameter at no extra computational cost. We show how our approach leverages momentum dynamics by establishing a fundamental equivalence between momentum diffusion models and conventional diffusion models with respect to their training paradigms. Moreover, we observe that the use of higher-dimensional noise naturally exhibits characteristics similar to stochastic differential equations (SDEs). Finally, we demonstrate strong performance on a set of representative pretrained diffusion models, including EDM, EDM2, and Stable-Diffusion 3, which cover models in both pixel and latent spaces, as well as class and text conditional settings. The code is available at https://github.com/apple/ml-tada.
Submitted 26 June, 2025;
originally announced June 2025.
-
Composition and Control with Distilled Energy Diffusion Models and Sequential Monte Carlo
Authors:
James Thornton,
Louis Bethune,
Ruixiang Zhang,
Arwen Bradley,
Preetum Nakkiran,
Shuangfei Zhai
Abstract:
Diffusion models may be formulated as a time-indexed sequence of energy-based models, where the score corresponds to the negative gradient of an energy function. As opposed to learning the score directly, an energy parameterization is attractive as the energy itself can be used to control generation via Monte Carlo samplers. Architectural constraints and training instability in energy parameterized models have so far yielded inferior performance compared to directly approximating the score or denoiser. We address these deficiencies by introducing a novel training regime for the energy function through distillation of pre-trained diffusion models, resembling a Helmholtz decomposition of the score vector field. We further showcase the synergies between energy and score by casting the diffusion sampling procedure as a Feynman Kac model where sampling is controlled using potentials from the learnt energy functions. The Feynman Kac model formalism enables composition and low temperature sampling through sequential Monte Carlo.
Submitted 18 February, 2025;
originally announced February 2025.
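The relation stated in the abstract above, that the score is the negative gradient of an energy function, can be illustrated with a toy one-dimensional case. This is a minimal sketch, not the paper's method: the Gaussian energy, the function names, and the finite-difference gradient are all assumptions chosen for illustration.

```python
def gaussian_energy(x, sigma=1.0):
    """Hypothetical toy energy: E(x) = x^2 / (2 sigma^2), i.e. a zero-mean Gaussian
    up to an additive constant."""
    return 0.5 * x * x / (sigma * sigma)


def score_from_energy(x, sigma=1.0, eps=1e-5):
    """Score as the negative gradient of the energy, using a central difference."""
    grad = (gaussian_energy(x + eps, sigma) - gaussian_energy(x - eps, sigma)) / (2.0 * eps)
    return -grad


def analytic_score(x, sigma=1.0):
    """Closed-form score of the same Gaussian: -x / sigma^2."""
    return -x / (sigma * sigma)


if __name__ == "__main__":
    for x in (-1.0, 0.5, 2.0):
        print(x, score_from_energy(x, sigma=2.0), analytic_score(x, sigma=2.0))
```

In this toy setting the numerical and analytic scores agree up to floating-point error, which is the sense in which an energy parameterization determines the score.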
-
Deep Generative Quantile Bayes
Authors:
Jungeum Kim,
Percy S. Zhai,
Veronika Ročková
Abstract:
We develop a multivariate posterior sampling procedure through deep generative quantile learning. Simulation proceeds implicitly through a push-forward mapping that can transform i.i.d. random vector samples into samples from the posterior. We utilize Monge-Kantorovich depth in multivariate quantiles to directly sample from Bayesian credible sets, a unique feature not offered by typical posterior sampling methods. To enhance the training of the quantile mapping, we design a neural network that automatically performs summary statistic extraction. This additional neural network structure has performance benefits, including support shrinkage (i.e., contraction of our posterior approximation) as the observation sample size increases. We demonstrate the usefulness of our approach on several examples where the absence of likelihood renders classical MCMC infeasible. Finally, we provide the following frequentist theoretical justifications for our quantile learning framework: consistency of the estimated vector quantile, of the recovered posterior distribution, and of the corresponding Bayesian credible sets.
Submitted 29 April, 2025; v1 submitted 10 October, 2024;
originally announced October 2024.
-
Improving GFlowNets for Text-to-Image Diffusion Alignment
Authors:
Dinghuai Zhang,
Yizhe Zhang,
Jiatao Gu,
Ruixiang Zhang,
Josh Susskind,
Navdeep Jaitly,
Shuangfei Zhai
Abstract:
Diffusion models have become the de-facto approach for generating visual data, which are trained to match the distribution of the training dataset. In addition, we also want to control generation to fulfill desired properties such as alignment to a text description, which can be specified with a black-box reward function. Prior works fine-tune pretrained diffusion models to achieve this goal through reinforcement learning-based algorithms. Nonetheless, they suffer from issues including slow credit assignment as well as low quality in their generated samples. In this work, we explore techniques that do not directly maximize the reward but rather generate high-reward images with relatively high probability -- a natural scenario for the framework of generative flow networks (GFlowNets). To this end, we propose the Diffusion Alignment with GFlowNet (DAG) algorithm to post-train diffusion models with black-box property functions. Extensive experiments on Stable Diffusion and various reward specifications corroborate that our method could effectively align large-scale text-to-image diffusion models with given reward information.
Submitted 25 December, 2024; v1 submitted 2 June, 2024;
originally announced June 2024.
-
Stabilizing Transformer Training by Preventing Attention Entropy Collapse
Authors:
Shuangfei Zhai,
Tatiana Likhomanenko,
Etai Littwin,
Dan Busbridge,
Jason Ramapuram,
Yizhe Zhang,
Jiatao Gu,
Josh Susskind
Abstract:
Training stability is of great importance to Transformers. In this work, we investigate the training dynamics of Transformers by examining the evolution of the attention layers. In particular, we track the attention entropy for each attention head during the course of training, which is a proxy for model sharpness. We identify a common pattern across different architectures and tasks, where low attention entropy is accompanied by high training instability, which can take the form of oscillating loss or divergence. We denote the pathologically low attention entropy, corresponding to highly concentrated attention scores, as "entropy collapse". As a remedy, we propose $σ$Reparam, a simple and efficient solution where we reparametrize all linear layers with spectral normalization and an additional learned scalar. We demonstrate that $σ$Reparam successfully prevents entropy collapse in the attention layers, promoting more stable training. Additionally, we prove a tight lower bound of the attention entropy, which decreases exponentially fast with the spectral norm of the attention logits, providing additional motivation for our approach. We conduct experiments with $σ$Reparam on image classification, image self-supervised learning, machine translation, speech recognition, and language modeling tasks. We show that $σ$Reparam provides stability and robustness with respect to the choice of hyperparameters, going so far as enabling training (a) a Vision Transformer to competitive performance without warmup, weight decay, layer normalization or adaptive optimizers; (b) deep architectures in machine translation and (c) speech recognition to competitive performance without warmup and adaptive optimizers. Code is available at https://github.com/apple/ml-sigma-reparam.
Submitted 25 July, 2023; v1 submitted 10 March, 2023;
originally announced March 2023.
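The two quantities named in the abstract above, attention entropy and a spectrally normalized weight with a learned scalar, can be sketched in a few lines. This is a rough illustration only: the rescaling `gamma * W / sigma(W)` is an assumption based on the abstract's wording ("spectral normalization and an additional learned scalar"), the function names are hypothetical, and the authors' repository should be consulted for the actual implementation.

```python
import math


def softmax(row):
    """Numerically stable softmax of a list of logits."""
    m = max(row)
    exps = [math.exp(v - m) for v in row]
    total = sum(exps)
    return [e / total for e in exps]


def attention_entropy(logits):
    """Mean Shannon entropy (in nats) of the row-wise softmax of a logit matrix."""
    entropies = []
    for row in logits:
        probs = softmax(row)
        entropies.append(-sum(p * math.log(p) for p in probs if p > 0.0))
    return sum(entropies) / len(entropies)


def spectral_norm(weight, iters=100):
    """Largest singular value via power iteration on W^T W.

    Assumes a non-zero matrix and a start vector that is not orthogonal
    to the top right singular vector.
    """
    n_rows, n_cols = len(weight), len(weight[0])
    v = [1.0 / math.sqrt(n_cols)] * n_cols
    for _ in range(iters):
        u = [sum(w * x for w, x in zip(row, v)) for row in weight]  # u = W v
        v = [sum(weight[i][j] * u[i] for i in range(n_rows)) for j in range(n_cols)]  # v = W^T u
        norm = math.sqrt(sum(x * x for x in v))
        v = [x / norm for x in v]
    u = [sum(w * x for w, x in zip(row, v)) for row in weight]
    return math.sqrt(sum(x * x for x in u))  # ||W v|| for unit v


def sigma_reparam(weight, gamma=1.0):
    """Assumed form of the reparameterization: gamma * W / sigma(W)."""
    sigma = spectral_norm(weight)
    return [[gamma * w / sigma for w in row] for row in weight]


if __name__ == "__main__":
    print(attention_entropy([[0.0, 0.0, 0.0, 0.0]]))   # uniform attention
    print(attention_entropy([[10.0, 0.0, 0.0, 0.0]]))  # concentrated attention
    print(spectral_norm(sigma_reparam([[3.0, 0.0], [0.0, 1.0]], gamma=2.0)))
```

Under this assumed form, the rescaled weight has spectral norm equal to `gamma`, so a single scalar bounds how large the logits produced by that layer can grow, which is the link the abstract draws between spectral norm and attention entropy.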
-
High-dimensional Functional Graphical Model Structure Learning via Neighborhood Selection Approach
Authors:
Boxin Zhao,
Percy S. Zhai,
Y. Samuel Wang,
Mladen Kolar
Abstract:
Undirected graphical models are widely used to model the conditional independence structure of vector-valued data. However, in many modern applications, for example those involving EEG and fMRI data, observations are more appropriately modeled as multivariate random functions rather than vectors. Functional graphical models have been proposed to model the conditional independence structure of such functional data. We propose a neighborhood selection approach to estimate the structure of Gaussian functional graphical models, where we first estimate the neighborhood of each node via a function-on-function regression and subsequently recover the entire graph structure by combining the estimated neighborhoods. Our approach only requires assumptions on the conditional distributions of random functions, and we estimate the conditional independence structure directly. We thus circumvent the need for a well-defined precision operator that may not exist when the functions are infinite dimensional. Additionally, the neighborhood selection approach is computationally efficient and can be easily parallelized. The statistical consistency of the proposed method in the high-dimensional setting is supported by both theory and experimental results. In addition, we study the effect of the choice of the function basis used for dimensionality reduction in an intermediate step. We give a heuristic criterion for choosing a function basis and motivate two practically useful choices, which we justify by both theory and experiments.
Submitted 25 January, 2024; v1 submitted 6 May, 2021;
originally announced May 2021.
-
Set Distribution Networks: a Generative Model for Sets of Images
Authors:
Shuangfei Zhai,
Walter Talbott,
Miguel Angel Bautista,
Carlos Guestrin,
Josh M. Susskind
Abstract:
Images with shared characteristics naturally form sets. For example, in a face verification benchmark, images of the same identity form sets. For generative models, the standard way of dealing with sets is to represent each as a one hot vector, and learn a conditional generative model $p(\mathbf{x}|\mathbf{y})$. This representation assumes that the number of sets is limited and known, such that the distribution over sets reduces to a simple multinomial distribution. In contrast, we study a more generic problem where the number of sets is large and unknown. We introduce Set Distribution Networks (SDNs), a novel framework that learns to autoencode and freely generate sets. We achieve this by jointly learning a set encoder, set discriminator, set generator, and set prior. We show that SDNs are able to reconstruct image sets that preserve salient attributes of the inputs in our benchmark datasets, and are also able to generate novel objects/identities. We examine the sets generated by SDN with a pre-trained 3D reconstruction network and a face verification network, respectively, as a novel way to evaluate the quality of generated sets of images.
Submitted 18 June, 2020;
originally announced June 2020.
-
Collegial Ensembles
Authors:
Etai Littwin,
Ben Myara,
Sima Sabah,
Joshua Susskind,
Shuangfei Zhai,
Oren Golan
Abstract:
Modern neural network performance typically improves as model size increases. A recent line of research on the Neural Tangent Kernel (NTK) of over-parameterized networks indicates that the improvement with size increase is a product of a better conditioned loss landscape. In this work, we investigate a form of over-parameterization achieved through ensembling, where we define collegial ensembles (CE) as the aggregation of multiple independent models with identical architectures, trained as a single model. We show that the optimization dynamics of CE simplify dramatically when the number of models in the ensemble is large, resembling the dynamics of wide models, yet scale much more favorably. We use recent theoretical results on the finite width corrections of the NTK to perform efficient architecture search in a space of finite width CE that aims to either minimize capacity, or maximize trainability under a set of constraints. The resulting ensembles can be efficiently implemented in practical architectures using group convolutions and block diagonal layers. Finally, we show how our framework can be used to analytically derive optimal group convolution modules originally found using expensive grid searches, without having to train a single model.
Submitted 17 June, 2020; v1 submitted 13 June, 2020;
originally announced June 2020.
-
Adversarial Fisher Vectors for Unsupervised Representation Learning
Authors:
Shuangfei Zhai,
Walter Talbott,
Carlos Guestrin,
Joshua M. Susskind
Abstract:
We examine Generative Adversarial Networks (GANs) through the lens of deep Energy Based Models (EBMs), with the goal of exploiting the density model that follows from this formulation. In contrast to a traditional view where the discriminator learns a constant function when reaching convergence, here we show that it can provide useful information for downstream tasks, e.g., feature extraction for classification. To be concrete, in the EBM formulation, the discriminator learns an unnormalized density function (i.e., the negative energy term) that characterizes the data manifold. We propose to evaluate both the generator and the discriminator by deriving corresponding Fisher Score and Fisher Information from the EBM. We show that by assuming that the generated examples form an estimate of the learned density, both the Fisher Information and the normalized Fisher Vectors are easy to compute. We also show that we are able to derive a distance metric between examples and between sets of examples. We conduct experiments showing that the GAN-induced Fisher Vectors demonstrate competitive performance as unsupervised feature extractors for classification and perceptual similarity tasks. Code is available at https://github.com/apple/ml-afv.
Submitted 29 October, 2019;
originally announced October 2019.
-
Addressing the Loss-Metric Mismatch with Adaptive Loss Alignment
Authors:
Chen Huang,
Shuangfei Zhai,
Walter Talbott,
Miguel Angel Bautista,
Shih-Yu Sun,
Carlos Guestrin,
Josh Susskind
Abstract:
In most machine learning training paradigms a fixed, often handcrafted, loss function is assumed to be a good proxy for an underlying evaluation metric. In this work we assess this assumption by meta-learning an adaptive loss function to directly optimize the evaluation metric. We propose a sample efficient reinforcement learning approach for adapting the loss dynamically during training. We empirically show how this formulation improves performance by simultaneously optimizing the evaluation metric and smoothing the loss landscape. We verify our method in metric learning and classification scenarios, showing considerable improvements over the state-of-the-art on a diverse set of tasks. Importantly, our method is applicable to a wide range of loss functions and evaluation metrics. Furthermore, the learned policies are transferable across tasks and data, demonstrating the versatility of the method.
Submitted 14 May, 2019;
originally announced May 2019.
-
Boosting Deep Learning Risk Prediction with Generative Adversarial Networks for Electronic Health Records
Authors:
Zhengping Che,
Yu Cheng,
Shuangfei Zhai,
Zhaonan Sun,
Yan Liu
Abstract:
The rapid growth of Electronic Health Records (EHRs), as well as the accompanying opportunities in Data-Driven Healthcare (DDH), has been attracting widespread interest and attention. Recent progress in the design and applications of deep learning methods has shown promising results and is forcing massive changes in healthcare academia and industry, but most of these methods rely on massive labeled data. In this work, we propose a general deep learning framework which is able to boost risk prediction performance with limited EHR data. Our model uses a modified generative adversarial network, namely ehrGAN, which can provide plausible labeled EHR data by mimicking real patient records, to augment the training dataset in a semi-supervised learning manner. We use this generative model together with a convolutional neural network (CNN) based prediction model to improve the onset prediction performance. Experiments on two real healthcare datasets demonstrate that our proposed framework produces realistic data samples and achieves significant improvements on classification tasks with the generated data over several state-of-the-art baselines.
Submitted 5 September, 2017;
originally announced September 2017.
-
Structural Correspondence Learning for Cross-lingual Sentiment Classification with One-to-many Mappings
Authors:
Nana Li,
Shuangfei Zhai,
Zhongfei Zhang,
Boying Liu
Abstract:
Structural correspondence learning (SCL) is an effective method for cross-lingual sentiment classification. This approach uses unlabeled documents along with a word translation oracle to automatically induce task-specific, cross-lingual correspondences. It transfers knowledge through identifying important features, i.e., pivot features. For simplicity, however, it assumes that the word translation oracle maps each pivot feature in the source language to exactly one word in the target language. This one-to-one mapping between words in different languages is too strict, and the context is not considered at all. In this paper, we propose a cross-lingual SCL based on distributed representation of words; it can learn meaningful one-to-many mappings for pivot words using large amounts of monolingual data and a small dictionary. We conduct experiments on the NLP&CC 2013 cross-lingual sentiment analysis dataset, employing English as the source language and Chinese as the target language. Our method does not rely on parallel corpora, and the experimental results show that our approach is more competitive than the state-of-the-art methods in cross-lingual sentiment classification.
Submitted 26 November, 2016;
originally announced November 2016.
-
Deep Structured Energy Based Models for Anomaly Detection
Authors:
Shuangfei Zhai,
Yu Cheng,
Weining Lu,
Zhongfei Zhang
Abstract:
In this paper, we attack the anomaly detection problem by directly modeling the data distribution with deep architectures. We propose deep structured energy based models (DSEBMs), where the energy function is the output of a deterministic deep neural network with structure. We develop novel model architectures to integrate EBMs with different types of data such as static data, sequential data, and spatial data, and apply appropriate model architectures to adapt to the data structure. Our training algorithm is built upon the recent development of score matching \cite{sm}, which connects an EBM with a regularized autoencoder, eliminating the need for complicated sampling methods. A statistically sound decision criterion can be derived for anomaly detection from the perspective of the energy landscape of the data distribution. We investigate two decision criteria for performing anomaly detection: the energy score and the reconstruction error. Extensive empirical studies on benchmark tasks demonstrate that our proposed model consistently matches or outperforms all the competing methods.
Submitted 15 June, 2016; v1 submitted 24 May, 2016;
originally announced May 2016.