Fast unsupervised ground metric learning with tree-Wasserstein distance
Authors:
Kira M. Düsterwald,
Samo Hromadka,
Makoto Yamada
Abstract:
The performance of unsupervised methods such as clustering depends on the choice of distance metric between features, or ground metric. Commonly, ground metrics are decided with heuristics or learned via supervised algorithms. However, since many interesting datasets are unlabelled, unsupervised ground metric learning approaches have been introduced. One promising option employs Wasserstein singular vectors (WSVs), which emerge when computing optimal transport distances between features and samples simultaneously. WSVs are effective, but can be prohibitively computationally expensive in some applications: $\mathcal{O}(n^2m^2(n \log(n) + m \log(m)))$ for $n$ samples and $m$ features. In this work, we propose to augment the WSV method by embedding samples and features on trees, on which we compute the tree-Wasserstein distance (TWD). We demonstrate theoretically and empirically that the algorithm converges to a better approximation of the standard WSV approach than the best known alternatives, and does so with $\mathcal{O}(n^3+m^3+mn)$ complexity. In addition, we prove that the initial tree structure can be chosen flexibly, since tree geometry does not constrain the richness of the approximation up to the number of edge weights. This proof suggests a fast, recursive algorithm for computing the tree parameter basis set, which we find crucial to realising the efficiency gains at scale. Finally, we apply the tree-WSV algorithm to several single-cell RNA sequencing genomics datasets, demonstrating its scalability and utility for unsupervised cell-type clustering problems. These results position unsupervised ground metric learning with TWD as a low-rank approximation of WSV with the potential for widespread application.
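Part of the efficiency of the tree-based approach comes from the fact that TWD has a closed form. On a rooted tree, the 1-Wasserstein distance equals the sum over edges of the edge length times the absolute net mass difference in the subtree below that edge, so a single distance costs time linear in the number of nodes. The minimal sketch below illustrates only this TWD primitive, not the tree-WSV algorithm or its basis-set construction. The parent-array encoding (root at index 0, parent[v] < v for every other node) and the function name `tree_wasserstein` are assumptions made for illustration.

```python
import numpy as np

def tree_wasserstein(parent, weight, a, b):
    """1-Wasserstein distance between mass vectors a and b placed on tree nodes.

    Assumed encoding (hypothetical): node 0 is the root with parent[0] = -1,
    parent[v] < v for every other node, and weight[v] is the length of the
    edge from v up to parent[v] (weight[0] is ignored).
    """
    mass = np.asarray(a, dtype=float) - np.asarray(b, dtype=float)
    total = 0.0
    # Visit nodes from the highest index down, so each node's subtree mass
    # is complete before it is pushed up to its parent.
    for v in range(len(parent) - 1, 0, -1):
        total += weight[v] * abs(mass[v])  # net mass crossing edge (v, parent[v])
        mass[parent[v]] += mass[v]
    return total

# Hypothetical tree: 0 -> {1, 2}, 1 -> {3, 4}
parent = [-1, 0, 0, 1, 1]
weight = [0.0, 1.0, 1.0, 0.5, 0.5]
print(tree_wasserstein(parent, weight, [0, 0, 0, 1, 0], [0, 0, 0, 0, 1]))
```

Moving a unit of mass from node 3 to node 4 only crosses the two short edges below node 1, so the result equals the tree path length between those nodes.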
Submitted 10 January, 2025; v1 submitted 11 November, 2024;
originally announced November 2024.
An Empirical Study of Self-supervised Learning with Wasserstein Distance
Authors:
Makoto Yamada,
Yuki Takezawa,
Guillaume Houry,
Kira Michaela Dusterwald,
Deborah Sulem,
Han Zhao,
Yao-Hung Hubert Tsai
Abstract:
In this study, we investigate self-supervised learning (SSL) using the 1-Wasserstein distance on a tree structure (a.k.a. the Tree-Wasserstein distance (TWD)), where TWD is defined as the L1 distance between two tree-embedded vectors. SSL methods often use cosine similarity as the objective function; however, using the Wasserstein distance instead has not been well studied. Training with the Wasserstein distance is numerically challenging. Thus, this study empirically investigates strategies for optimizing SSL with the Wasserstein distance and identifies a stable training procedure. More specifically, we evaluate combinations of two types of TWD (total variation and ClusterTree) and several probability models, including the softmax function, the ArcFace probability model, and simplicial embedding. We propose a simple yet effective Jeffrey divergence-based regularization method to stabilize optimization. Through empirical experiments on STL10, CIFAR10, CIFAR100, and SVHN, we find that a simple combination of the softmax function and TWD yields significantly lower performance than standard SimCLR. Moreover, a simple combination of TWD and SimSiam fails to train. We find that model performance depends on the combination of TWD and probability model, and that the Jeffrey divergence regularization helps model training. Finally, we show that an appropriate combination of TWD and probability model outperforms cosine similarity-based representation learning.
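The definition of TWD as an L1 distance between tree-embedded vectors can be written as $\mathrm{TWD}(p, q) = \lVert \mathrm{diag}(w)\, B (p - q) \rVert_1$, where $B_{vj} = 1$ if leaf $j$ lies in the subtree of node $v$ and $w$ holds the edge lengths. The minimal NumPy sketch below pairs this with a softmax probability model and computes a Jeffrey (symmetric KL) divergence, $\mathrm{KL}(p\|q) + \mathrm{KL}(q\|p)$. The tree, the names `twd_l1` and `jeffrey`, and the inputs are hypothetical; how the paper weights and combines these terms in its training objective is not stated here, so none is assumed. As a sanity check, a star tree ($B$ the identity) with all edge lengths 1/2 reduces TWD to the total variation distance.

```python
import numpy as np

def softmax(z):
    """Softmax probability model: map logits to a strictly positive distribution."""
    z = np.asarray(z, dtype=float)
    e = np.exp(z - z.max())  # subtract the max for numerical stability
    return e / e.sum()

def twd_l1(B, w, p, q):
    """TWD as the weighted L1 distance between tree-embedded vectors B @ p and B @ q."""
    return float(np.sum(w * np.abs(B @ (p - q))))

def jeffrey(p, q):
    """Jeffrey divergence KL(p||q) + KL(q||p); p and q must be strictly positive."""
    return float(np.sum((p - q) * (np.log(p) - np.log(q))))

# Hypothetical tree: root -> internal node u -> leaves 0 and 1; root -> leaf 2.
# Rows of B are the non-root nodes (u, leaf 0, leaf 1, leaf 2); columns are leaves.
B = np.array([[1, 1, 0],
              [1, 0, 0],
              [0, 1, 0],
              [0, 0, 1]], dtype=float)
w = np.array([1.0, 0.5, 0.5, 1.0])  # length of the edge above each row's node

p = softmax([2.0, 0.5, -1.0])  # e.g. model outputs for two augmented views
q = softmax([0.0, 1.5, 0.5])
print(twd_l1(B, w, p, q), jeffrey(p, q))
```

Because the softmax outputs are strictly positive, the logarithms in `jeffrey` are always finite, which is one reason a softmax model pairs naturally with a KL-based regularizer.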
Submitted 5 February, 2024; v1 submitted 16 October, 2023;
originally announced October 2023.