Research overview

Matrix Recovery and Hessian-based Measurements of Neural Networks

Modern neural networks are typically designed with far more parameters than training examples, a regime known as over-parameterization. Classical generalization theory, e.g., from the VC dimension or Rademacher complexity perspectives, often yields vacuous bounds that do not explain why gradient descent finds solutions that generalize. From an optimization perspective, although gradient descent is empirically observed to converge to flat, low-rank solutions in this regime, theoretical understanding of this implicit bias has been limited beyond simplified linear models. Further, existing generalization bounds for architectures such as graph neural networks scale with coarse graph statistics (e.g., the maximum degree) and are therefore loose on real-world graphs.

My contribution in this area is to study generalization and optimization from a second-order perspective. In particular, we develop an algorithmic framework for measuring spectral statistics of the Hessian and apply this framework to large models. We begin by studying the implicit regularization of gradient descent in over-parameterized matrix sensing from small initializations. We then consider supervised fine-tuning, show that Hessian-based generalization bounds are non-vacuous, and derive several downstream implications for regularization and quantization. Finally, we apply this technique to obtain new generalization bounds for graph neural networks.

Matrix recovery

In a COLT’18 paper, we consider gradient descent dynamics in over-parameterized matrix sensing, where we are given linear measurements of an unknown matrix M. We provide a detailed convergence analysis, starting from a small initialization, for recovering M when the number of parameters exceeds the number of training data points. A key insight is that the gradient descent iterates remain close to a low-rank subspace, ultimately converging to the minimum-nuclear-norm solution among all interpolating solutions.
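To make the setting concrete, here is a minimal NumPy sketch of gradient descent on the factorized objective f(U) = (1/2m) Σ_k (⟨A_k, UU⊤⟩ − y_k)² with a full d×d factor U initialized at a small scale alpha. The symmetrized Gaussian sensing matrices, the dimensions, the step size, and the function name `matrix_sensing_gd` are illustrative assumptions, not the exact setup analyzed in the paper.

```python
import numpy as np

def matrix_sensing_gd(A, y, alpha=1e-3, lr=0.02, steps=5000, seed=0):
    """Gradient descent on f(U) = 1/(2m) * sum_k (<A_k, U U^T> - y_k)^2.

    A: (m, d, d) array of symmetric sensing matrices; y: (m,) measurements.
    U is a full d x d factor (over-parameterized), initialized at scale alpha.
    """
    m, d, _ = A.shape
    A_flat = A.reshape(m, d * d)  # row k holds the entries of A_k
    U = alpha * np.random.default_rng(seed).standard_normal((d, d))
    losses = []
    for _ in range(steps):
        residual = A_flat @ (U @ U.T).ravel() - y
        losses.append(0.5 * np.mean(residual ** 2))
        # For symmetric A_k, the gradient is (2/m) * sum_k residual_k * A_k @ U
        grad = (2.0 / m) * (residual @ A_flat).reshape(d, d) @ U
        U = U - lr * grad
    return U, losses

# Hypothetical instance: rank-1 ground truth, fewer measurements (m) than parameters (d * d)
rng = np.random.default_rng(1)
d, r, m = 20, 1, 100
U_star = rng.standard_normal((d, r)) / np.sqrt(d)
M_star = U_star @ U_star.T
G = rng.standard_normal((m, d, d))
A = (G + G.transpose(0, 2, 1)) / 2  # symmetrize each sensing matrix
y = A.reshape(m, -1) @ M_star.ravel()

U, losses = matrix_sensing_gd(A, y)
rel_err = np.linalg.norm(U @ U.T - M_star) / np.linalg.norm(M_star)
print(f"final loss {losses[-1]:.2e}, relative error {rel_err:.2e}")
```

The parameter that matters most for the implicit bias is the initialization scale `alpha`; varying it is a natural way to experiment with the small-initialization regime the analysis concerns.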

Building on these techniques, we consider matrix completion in the ultra-sparse sampling regime: each entry of the unknown n×d matrix M is observed with probability p = C/d for some fixed constant C. Cao, Liang, and Valiant posed the open problem of whether one side of M, namely the second-moment matrix M⊤M/n, can be recovered accurately in this ultra-sparse regime. Our paper resolves this problem in the affirmative. A key technique is self-normalization, as in the Hájek estimator. We show that this estimator is unbiased for the second-moment matrix and, moreover, reduces variance, yielding more accurate estimates in practice.
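The difference between plain inverse-probability weighting (IPW) and self-normalization can be sketched in a few lines of NumPy. IPW divides the sum of co-observed products by its expected count, while the self-normalized (Hájek-style) version divides by the realized number of rows in which both entries are observed. The input format (a matrix plus a boolean observation mask), the function name, and the synthetic low-rank instance are assumptions for illustration; the estimator in the paper may include further corrections.

```python
import numpy as np

def second_moment_estimates(M, mask, p):
    """Estimate M^T M / n from the entries of the n x d matrix M where mask is True.

    Returns (ipw, hajek), two d x d estimates of the second-moment matrix.
    """
    n, d = M.shape
    X = np.where(mask, M, 0.0)  # zero out unobserved entries
    S = X.T @ X                 # S[j, k] sums products over rows where both j and k are observed
    counts = mask.T.astype(float) @ mask.astype(float)  # number of such rows
    # IPW: an off-diagonal pair is co-observed with probability p^2, a diagonal entry with probability p
    expected = np.full((d, d), n * p ** 2)
    np.fill_diagonal(expected, n * p)
    ipw = S / expected
    # Self-normalization: divide by the realized count instead of its expectation
    hajek = np.divide(S, counts, out=np.zeros_like(S), where=counts > 0)
    return ipw, hajek

# Hypothetical ultra-sparse instance with p = C / d
rng = np.random.default_rng(0)
n, d, C = 20000, 50, 5.0
p = C / d
M = rng.standard_normal((n, 3)) @ rng.standard_normal((3, d))
mask = rng.random((n, d)) < p
ipw, hajek = second_moment_estimates(M, mask, p)
target = M.T @ M / n
for name, est in [("IPW", ipw), ("Hajek", hajek)]:
    print(name, np.linalg.norm(est - target) / np.linalg.norm(target))
```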

Moving from matrix recovery to neural networks, I have studied data augmentation, a practical method for improving generalization. In joint work with Greg Valiant and Chris Ré at ICML’20, we measure the bias and variance effects of both invariant and mixup transformations, and formalize these findings in the over-parameterized linear regression setting.
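For reference, mixup in its standard form draws λ ~ Beta(α, α) and replaces a pair of examples with their convex combination (λx_i + (1−λ)x_j, λy_i + (1−λ)y_j). The NumPy sketch below applies it to a synthetic over-parameterized linear regression problem (more features than examples); the Beta parameter, the sizes, and the minimum-norm fit are illustrative choices rather than the exact setting of the paper.

```python
import numpy as np

def mixup(X, y, alpha=0.2, rng=None):
    """Mix each example with a randomly chosen partner using weights lam ~ Beta(alpha, alpha)."""
    rng = np.random.default_rng() if rng is None else rng
    n = X.shape[0]
    lam = rng.beta(alpha, alpha, size=n)
    perm = rng.permutation(n)
    X_mix = lam[:, None] * X + (1 - lam[:, None]) * X[perm]
    y_mix = lam * y + (1 - lam) * y[perm]
    return X_mix, y_mix

# Illustrative over-parameterized linear regression: d features > n examples
rng = np.random.default_rng(0)
n, d = 30, 100
w_true = rng.standard_normal(d) / np.sqrt(d)
X = rng.standard_normal((n, d))
y = X @ w_true + 0.1 * rng.standard_normal(n)

X_mix, y_mix = mixup(X, y, rng=rng)
X_aug, y_aug = np.vstack([X, X_mix]), np.concatenate([y, y_mix])
w_hat = np.linalg.pinv(X_aug) @ y_aug  # minimum-norm least-squares fit on the augmented data
```

Because mixup forms convex combinations of both inputs and labels, mixed points with noiseless linear labels stay on the same hyperplane as the original data.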

Hessian-based measurements for fine-tuning

Next, I consider supervised fine-tuning, in which gradient descent starts from a pretrained model rather than a random initialization, which introduces challenges that the above results alone cannot address. In a line of work with my students Haotian Ju and Dongyue Li, we prove generalization bounds for fine-tuning. Assuming that the fine-tuned model lies within a certain radius of the pretrained model, we show that the generalization gap is upper-bounded by the (maximum) trace of the Hessian of the loss function over the entire data distribution, times the radius squared. A particularly useful property of these Hessian-based measurements is that they are non-vacuous, meaning that they can match the scale of empirically observed generalization gaps. This is crucial because it makes the Hessian-based framework meaningful for downstream applications.
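Hessian traces of large models are usually estimated without forming the Hessian, by combining Hessian-vector products with Hutchinson's randomized estimator tr(H) = E[v⊤Hv] for Rademacher vectors v. The NumPy sketch below illustrates this on a logistic-regression loss, chosen because its Hessian-vector product has a closed form and the exact trace is available for comparison. The function names are hypothetical, and this is not necessarily the exact estimator used in our papers.

```python
import numpy as np

def logistic_hvp(w, X, v):
    """Hessian-vector product of the mean logistic loss at w; labels do not enter the Hessian."""
    probs = 1.0 / (1.0 + np.exp(-X @ w))
    weights = probs * (1.0 - probs)
    return X.T @ (weights * (X @ v)) / X.shape[0]

def hutchinson_trace(hvp, dim, num_probes=2000, seed=0):
    """Estimate tr(H) as the average of v^T H v over random Rademacher vectors v."""
    rng = np.random.default_rng(seed)
    total = 0.0
    for _ in range(num_probes):
        v = rng.choice([-1.0, 1.0], size=dim)
        total += v @ hvp(v)
    return total / num_probes

# Illustrative data and parameters
rng = np.random.default_rng(0)
n, d = 200, 10
X = rng.standard_normal((n, d))
w = 0.1 * rng.standard_normal(d)
estimate = hutchinson_trace(lambda v: logistic_hvp(w, X, v), d)
probs = 1.0 / (1.0 + np.exp(-X @ w))
exact = np.sum(probs * (1 - probs) * np.sum(X ** 2, axis=1)) / n  # tr(H) in closed form
print(estimate, exact)
```

For a neural network, `hvp` would instead come from automatic differentiation; the estimator itself only needs Hessian-vector products.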

Here are several algorithmic implications of this computational framework:

  • Noise Stability Optimization for Finding Flat Minima: We design a noise-injection algorithm with a regularization effect on the Hessian trace of the loss surface and empirically validate this algorithm in a variety of practical settings.

  • Accelerating Quantization-Aware Training of Language Models Around Saddle Points (with Dongyue Li et al.), International Conference on Machine Learning, 2026: We discover that noise injection can also help accelerate quantization-aware training, whose convergence is slowed by saddle points at which the loss Hessian has a large fraction of both positive and negative eigenvalues.

  • In preliminary work, we find empirical evidence that regularizing the Hessian trace also mitigates grokking in modular arithmetic tasks.
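The link between noise injection and the Hessian trace comes from a second-order Taylor expansion: for u ~ N(0, σ²I), E[f(w + u)] ≈ f(w) + (σ²/2)·tr(∇²f(w)), so descending on the noise-smoothed loss implicitly penalizes the trace. The NumPy sketch below takes gradient steps on the smoothed loss using antithetic pairs ±u. The Gaussian noise, the two-point averaging, and the quadratic test function are assumptions for illustration, and the sketch is not claimed to be the exact algorithm from the paper.

```python
import numpy as np

def noise_injected_step(w, grad_fn, lr, sigma, rng, num_pairs=1):
    """One stochastic gradient step on the smoothed loss E_u[f(w + u)], u ~ N(0, sigma^2 I).

    Averaging the gradients at w + u and w - u (antithetic pairs) cancels the term that is
    linear in u, reducing variance without changing the expected gradient.
    """
    g = np.zeros_like(w)
    for _ in range(num_pairs):
        u = sigma * rng.standard_normal(w.shape)
        g += 0.5 * (grad_fn(w + u) + grad_fn(w - u))
    return w - lr * g / num_pairs

# Quadratic check (illustrative): E[f(w + u)] = f(w) + (sigma^2 / 2) * tr(A) holds exactly here
A = np.diag([1.0, 2.0, 3.0, 4.0, 5.0])
f = lambda w: 0.5 * w @ A @ w
grad_f = lambda w: A @ w
rng = np.random.default_rng(0)
w0, sigma = np.ones(5), 0.1
U = sigma * rng.standard_normal((20000, 5))
smoothed_gap = np.mean([0.5 * (f(w0 + u) + f(w0 - u)) for u in U]) - f(w0)
predicted_gap = 0.5 * sigma ** 2 * np.trace(A)

w = w0.copy()
for _ in range(200):
    w = noise_injected_step(w, grad_f, lr=0.1, sigma=sigma, rng=rng)
```

On a quadratic the Hessian is constant, so the trace penalty has zero gradient and each step coincides with plain gradient descent; the quadratic only serves to check the trace identity.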

Improved generalization bounds for graph neural networks

The Hessian trace bound extends naturally to graph-structured data, where we also uncover a connection between this loss curvature and the spectral properties of the graph diffusion matrix. This connection allows us to improve on state-of-the-art generalization bounds for graph neural networks. Previous generalization bounds for graph neural networks scale with the graph structure, specifically the maximum degree over all vertices. We show a generalization bound that instead scales with the largest singular value of the graph diffusion matrix. For example, consider a node classification problem where each node represents a user or a video on a social network, and the goal is to predict each node's label for a recommendation system. In a graph convolutional network, the diffusion matrix is the symmetrically normalized adjacency matrix (sometimes referred to as the normalized graph Laplacian), whose largest singular value is at most one, while the maximum degree of such a network can be very large. As a result, our bounds are numerically much smaller than prior bounds on real-world graphs.
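A quick NumPy check shows why the diffusion-based quantity can be far smaller than the maximum degree: for the standard GCN propagation matrix D̃^(-1/2)(A + I)D̃^(-1/2), where D̃ is the degree matrix of A + I, the largest singular value equals one on any graph, while the maximum degree of a star graph grows with its size. The star graph and the helper name are illustrative, and the constants in our bounds are not reproduced here.

```python
import numpy as np

def gcn_diffusion(adj):
    """GCN propagation matrix D^{-1/2} (A + I) D^{-1/2}, with D the degree matrix of A + I."""
    a_hat = adj + np.eye(adj.shape[0])
    d_inv_sqrt = 1.0 / np.sqrt(a_hat.sum(axis=1))
    return d_inv_sqrt[:, None] * a_hat * d_inv_sqrt[None, :]

# Illustrative star graph: node 0 is connected to every other node
n = 50
star = np.zeros((n, n))
star[0, 1:] = 1.0
star[1:, 0] = 1.0
max_degree = int(star.sum(axis=1).max())                # 49
sigma_max = np.linalg.norm(gcn_diffusion(star), ord=2)  # largest singular value
print(max_degree, sigma_max)
```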

Multitask Learning and Foundation Models: Theoretical Understanding and New Algorithms

The problem of multitask learning is as follows: Given several related tasks, how can we train a neural network to make accurate predictions on all of them simultaneously? This line of work is influenced by the development of foundation models, which are often trained on diverse datasets. When different tasks are trained jointly in a network with shared parameters, how does information from one task transfer to another? Relatedly, how can we identify the most helpful tasks to train alongside a given downstream task? My contributions include (i) theoretical understanding of information transfer in multitask networks, (ii) practical algorithms for task/data partitioning and selection, and (iii) second-order analysis of task attribution that rigorously connects influence functions to linear surrogate modeling.

Theoretical understanding of information transfer in multitask learning

In an ICLR’20 paper with Chris Ré, we formally study transfer by relating multi-headed neural networks (a common architecture for multitask learning) to two-layer neural networks. Our paper is one of the first in-depth analyses of negative transfer in two-layer neural networks. With this connection, questions such as how one task affects another become amenable to statistical analysis. In a JMLR’25 paper, we improve on this initial result and provide a precise quantification of transfer in high-dimensional linear regression. We formulate hard parameter sharing estimation for two linear regression tasks in the high-dimensional, proportional regime. We show that, compared with single-task learning, there is a phase transition from positive transfer to negative transfer as the number of source-task samples increases.
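For two tasks with a shared linear layer B ∈ ℝ^p and scalar output heads, B can be rescaled so that the target head equals one; the target-task estimate then reduces to a reweighted pooled least-squares problem, β̂(a) = (X1⊤X1 + a²·X2⊤X2)⁻¹(X1⊤y1 + a·X2⊤y2), where a is the source head. The NumPy sketch below computes this estimate and tunes a on held-out target data. The one-dimensional shared layer, the validation-based choice of a, and the synthetic data are assumptions for illustration; setting a = 0 recovers single-task least squares.

```python
import numpy as np

def hps_target_estimate(X1, y1, X2, y2, a):
    """Target estimate B(a) = argmin_B ||X1 B - y1||^2 + ||a X2 B - y2||^2 (target head fixed to 1)."""
    gram = X1.T @ X1 + a ** 2 * (X2.T @ X2)
    return np.linalg.solve(gram, X1.T @ y1 + a * (X2.T @ y2))

def tune_source_head(X1, y1, X2, y2, X_val, y_val, grid):
    """Pick the source head a on a grid by validation error on the target task."""
    val_err = lambda a: np.mean((X_val @ hps_target_estimate(X1, y1, X2, y2, a) - y_val) ** 2)
    best = min(grid, key=val_err)
    return best, hps_target_estimate(X1, y1, X2, y2, best)

# Hypothetical two-task instance with a shifted source model
rng = np.random.default_rng(0)
p, n1, n2 = 50, 100, 400
beta1 = rng.standard_normal(p) / np.sqrt(p)
beta2 = beta1 + 0.5 * rng.standard_normal(p) / np.sqrt(p)
X1, X2, X_val = (rng.standard_normal((k, p)) for k in (n1, n2, 200))
y1 = X1 @ beta1 + rng.standard_normal(n1)
y2 = X2 @ beta2 + rng.standard_normal(n2)
y_val = X_val @ beta1 + rng.standard_normal(200)

a_best, beta_hps = tune_source_head(X1, y1, X2, y2, X_val, y_val, np.linspace(-2, 2, 41))
beta_stl, *_ = np.linalg.lstsq(X1, y1, rcond=None)  # single-task baseline (a = 0)
print(a_best, np.linalg.norm(beta_hps - beta1), np.linalg.norm(beta_stl - beta1))
```

Sweeping n2 in such a simulation is one way to explore how transfer changes with the number of source samples.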

Clustering algorithms for multitask learning

A key insight from this work is that negative transfer is inherent when tasks exhibit severe distribution shifts. To scale these insights to foundation models, we develop a surrogate modeling approach that predicts the performance of learning multiple tasks simultaneously. In a series of papers with my student Dongyue Li, in collaboration with Aneesh Sharma and Lu Wang, we design convex relaxation algorithms to find an approximate partition of tasks in multitask learning. The optimization program is based on an affinity matrix that captures task relationships and is estimated using a surrogate modeling approach. This yields a random-forest-style algorithm that captures higher-order correlations among tasks more accurately than existing methods and can be implemented efficiently on top of foundation models using a linearization technique reminiscent of neural tangent kernels. We have extensively validated this approach in a variety of downstream applications, including overlapping community detection, ensemble low-rank adaptation of language models, and demonstration selection for in-context learning.
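A minimal sketch of the surrogate step, assuming one linear surrogate per task: sample random subsets of tasks, train a model on each subset (abstracted here as a hypothetical `evaluate` oracle), and regress each task's score on the indicator vector of the subset. The fitted coefficients form an affinity matrix. The sampling scheme (each task kept with probability 1/2) is an assumption, and the convex-relaxation partitioning, the random-forest-style ensembling, and the linearization are not shown.

```python
import numpy as np

def fit_task_affinity(num_tasks, evaluate, num_subsets=300, seed=0):
    """Estimate a task-affinity matrix with one linear surrogate model per task.

    evaluate(subset) is a hypothetical oracle: it trains one model on the tasks in `subset`
    and returns {task: score} for every task in the subset.
    theta[i, j] approximates how much including task j changes task i's score.
    """
    rng = np.random.default_rng(seed)
    samples = {i: [] for i in range(num_tasks)}
    for _ in range(num_subsets):
        include = rng.random(num_tasks) < 0.5  # random subset, each task kept with prob. 1/2
        if not include.any():
            continue
        scores = evaluate(np.flatnonzero(include))
        for i in np.flatnonzero(include):
            samples[i].append((include.astype(float), scores[i]))
    theta = np.zeros((num_tasks, num_tasks))
    for i, rows in samples.items():
        Z = np.array([z for z, _ in rows])
        s = np.array([score for _, score in rows])
        # Column i is all ones (task i is always present), so theta[i, i] acts as an intercept
        theta[i], *_ = np.linalg.lstsq(Z, s, rcond=None)
    return theta

# Synthetic check with a hypothetical oracle whose scores are exactly additive over the subset
true_theta = np.random.default_rng(5).standard_normal((6, 6))
oracle = lambda subset: {int(i): float(true_theta[i, subset].sum()) for i in subset}
theta_hat = fit_task_affinity(6, oracle)
```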

Multi-objective reinforcement learning

A related problem that is amenable to the above techniques is multi-objective reinforcement learning, which involves balancing multiple conflicting objectives in RL. This problem has broad applications in modern AI, such as alignment and robotics. In an AAAI’26 paper, we apply the above approach on top of proximal policy optimization (PPO) to partition similar trajectories into groups. A key insight is to design a routing mechanism that directs each trajectory to the partition that yields the highest reward.

Understanding linear surrogate modeling via a Hessian analysis

Modern AI models are trained on diverse tasks, leading to the question of quantifying the influence of individual tasks upon a model, a problem we refer to as task attribution. This problem is closely related to the local geometry of loss landscapes, which can be captured by the Hessian matrix of the loss function. Prior work has captured this connection through influence functions and Hessian eigenmaps. In an ICLR’26 paper with my students Zhenshuo Zhang and Minxuan Duan, we rigorously connect influence functions to linear surrogate modeling, an empirical procedure practitioners often use to attribute the influence of data on the trained model.
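The connection can be illustrated in a setting where everything is explicit. For ridge regression over groups of examples (one group per task), the influence-function estimate of removing task k is Δθ ≈ H⁻¹∇ℓ_k(θ̂), where H is the Hessian of the training objective at θ̂ and ℓ_k is task k's loss. To first order, removing a set S of tasks then changes θ by Σ_{k∈S} H⁻¹∇ℓ_k(θ̂), a linear function of the indicator vector of S, which is the functional form a linear surrogate model fits. The NumPy sketch below checks the influence estimate against a finite-difference refit; the ridge model, the group structure, and the function names are assumptions for illustration, not the analysis in the paper.

```python
import numpy as np

def ridge_fit(groups, weights, lam):
    """Minimize sum_g w_g * 0.5 * ||X_g theta - y_g||^2 + 0.5 * lam * ||theta||^2; return (theta, Hessian)."""
    d = groups[0][0].shape[1]
    H = lam * np.eye(d)
    b = np.zeros(d)
    for (X, y), w in zip(groups, weights):
        H += w * X.T @ X
        b += w * X.T @ y
    return np.linalg.solve(H, b), H

def removal_influence(groups, lam, k):
    """Influence-function estimate of theta(without task k) - theta(all tasks)."""
    theta, H = ridge_fit(groups, np.ones(len(groups)), lam)
    X_k, y_k = groups[k]
    grad_k = X_k.T @ (X_k @ theta - y_k)  # gradient of task k's loss at theta
    return np.linalg.solve(H, grad_k)

# Hypothetical data: four tasks sharing one linear model, plus a finite-difference check
rng = np.random.default_rng(0)
d = 5
groups = [(rng.standard_normal((30, d)), rng.standard_normal(30)) for _ in range(4)]
lam, k, eps = 1.0, 2, 1e-5
theta, _ = ridge_fit(groups, np.ones(4), lam)
weights = np.ones(4)
weights[k] = 1 - eps  # downweight task k slightly
theta_eps, _ = ridge_fit(groups, weights, lam)
predicted = removal_influence(groups, lam, k)
```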

Learning and Algorithmic Reasoning on Large-Scale Networks

My work in this area spans learning and algorithmic reasoning on large-scale network data. This direction provides both a domain for testing the preceding ideas and a rich source of connections to classical algorithms and tools in spectral graph theory.

Dynamic PageRank computation

Earlier contributions include the first running-time analysis of local push (a deterministic algorithm for computing personalized PageRank) on dynamic graphs, in which we propose and analyze natural dynamic versions of known local variations of power methods for solving linear systems. Another contribution is a spectral algorithm for reducing epidemic diffusion on weighted graphs by minimizing the sum of the k largest eigenvalues, generalizing earlier work in this literature that reduces the top eigenvalue.
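For concreteness, here is a NumPy sketch of the standard static local push method for personalized PageRank (in the style of Andersen, Chung, and Lang), which the dynamic versions build on. It maintains an estimate p and a residual r with the invariant π_s = p + Σ_u r(u)·π_u, and pushes at any node whose residual is at least ε times its degree. The teleportation parameter, the threshold, the non-lazy update rule, and the small test graph are assumptions; the dynamic variants are not shown.

```python
import numpy as np

def local_push(neighbors, source, alpha=0.15, eps=1e-8):
    """Approximate personalized PageRank from `source` on an undirected graph.

    neighbors[u] lists the neighbors of u (no isolated nodes). Maintains the invariant
    ppr_source = p + sum_u r[u] * ppr_u and stops once r[u] < eps * deg(u) for every u.
    """
    n = len(neighbors)
    p, r = np.zeros(n), np.zeros(n)
    r[source] = 1.0
    stack = [source]
    while stack:
        u = stack.pop()
        deg = len(neighbors[u])
        if r[u] < eps * deg:
            continue  # stale entry: the residual was already pushed
        mass, r[u] = r[u], 0.0
        p[u] += alpha * mass              # keep an alpha fraction at u
        share = (1 - alpha) * mass / deg  # spread the rest uniformly over neighbors
        for v in neighbors[u]:
            r[v] += share
            if r[v] >= eps * len(neighbors[v]):
                stack.append(v)
    return p, r

# Illustrative graph: an 8-cycle with two chords
edges = [(i, (i + 1) % 8) for i in range(8)] + [(0, 4), (2, 6)]
neighbors = [[] for _ in range(8)]
for a, b in edges:
    neighbors[a].append(b)
    neighbors[b].append(a)
p, r = local_push(neighbors, source=0)

# Exact PPR for comparison: pi = alpha * e_s + (1 - alpha) * pi @ W, with W = D^{-1} A
A = np.zeros((8, 8))
for a, b in edges:
    A[a, b] = A[b, a] = 1.0
W = A / A.sum(axis=1, keepdims=True)
exact = 0.15 * np.linalg.solve(np.eye(8) - 0.85 * W.T, np.eye(8)[0])
print(np.abs(p - exact).sum())
```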

Traffic accident prediction

We have also contributed to this space by collecting new graph datasets (with my students Ziniu Zhang and Abhinav Nippani). We collected traffic accident records from state Department of Transportation websites and constructed a large-scale graph dataset representing the road networks of eight US states, including Massachusetts. Using this dataset, we study traffic accident prediction on road networks using graph neural networks (GNNs). We have also collected high-resolution satellite images of the road segments. Our results suggest that combining graph neural networks with satellite image embeddings can predict accident occurrences with an AUROC over 90%.

Algorithmic reasoning

Can neural networks learn to follow the behavior of an algorithm? We study this question on twelve classical algorithms, such as BFS, DFS, Dijkstra, Kruskal, Strongly Connected Components (SCC), and Floyd-Warshall. We train graph neural networks to predict the intermediate executions and final answers of these algorithms on Erdős-Rényi graphs with 16 nodes. Building on the multitask learning framework in the previous section, our contribution is a hierarchical, branching GNN that learns multiple algorithmic reasoning tasks simultaneously. We find that GNNs can learn simple algorithms such as BFS with over 99% accuracy and related algorithms such as SCC with over 90% accuracy, while struggling with more complex algorithms such as Floyd-Warshall, at just 63% accuracy. DFS accuracy remains below 40%, as it does for other baseline architectures, suggesting that the recursive search process in DFS is particularly difficult to learn. This work raises several questions regarding the learnability of algorithms by neural networks and length generalization to larger graphs or sequences.
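As an illustration of how intermediate executions can serve as supervision, the NumPy sketch below samples an Erdős-Rényi graph with 16 nodes and records, after each round of breadth-first search, which nodes have been reached; consecutive snapshots then form (input, target) pairs. The edge probability, the layer-synchronous form of BFS, and the boolean-vector encoding are assumptions, not the exact data format used in our experiments.

```python
import numpy as np

def erdos_renyi(n, p, rng):
    """Symmetric boolean adjacency matrix of a G(n, p) random graph without self-loops."""
    upper = np.triu(rng.random((n, n)) < p, k=1)
    return upper | upper.T

def bfs_trace(adj, source):
    """Layer-by-layer BFS; returns one boolean 'reached' vector per round."""
    reached = np.zeros(len(adj), dtype=bool)
    reached[source] = True
    trace = [reached.copy()]
    frontier = reached.copy()
    while frontier.any():
        # Nodes adjacent to the current frontier that have not been reached yet
        new = adj[frontier].any(axis=0) & ~reached
        reached |= new
        frontier = new
        if new.any():
            trace.append(reached.copy())
    return trace

# Hypothetical training example
rng = np.random.default_rng(0)
adj = erdos_renyi(16, 0.2, rng)
trace = bfs_trace(adj, source=0)
pairs = list(zip(trace[:-1], trace[1:]))  # one supervised (input, target) pair per BFS round
```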

Incentive Ratio, Rank Aggregation

My earlier work sought to quantify the extent to which selfish misreporting can increase a bidder's utility in a Fisher market equilibrium game, where each bidder has a finite budget.
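For reference, this quantity can be written as follows (the notation is an assumption, following common usage in the incentive-ratio literature): let b_i be bidder i's truthful report, b_{-i} the reports of the other bidders, x_i(·) the equilibrium allocation to bidder i, and u_i bidder i's true utility. The incentive ratio of bidder i is the largest factor by which a strategic report b_i′ can increase its utility, and the incentive ratio of the market is the supremum of this quantity over bidders and instances; it is at least one because reporting truthfully is always an option.

```latex
\zeta_i \;=\; \max_{b_i'} \frac{u_i\big(x_i(b_i', b_{-i})\big)}{u_i\big(x_i(b_i, b_{-i})\big)},
\qquad
\zeta \;=\; \sup_{\text{instances}} \; \max_i \; \zeta_i \;\ge\; 1 .
```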

More recent work along this line studies rank aggregation algorithms under incentive compatibility, as well as scaling these insights to language models.