Approximating Language Model Training Data from Weights

1. Approximating Language Model Training Data from Weights

Some LLMs are open-weights but not open-data. Given access to the initial model parameters (the model state before finetuning) and knowledge of the optimizer used (e.g., SGD vs. Adam), we seek to approximate the data used for finetuning.

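Formally, we assume access to some training algorithm $\mathcal{T}$ that, starting from initial parameters $\theta_0$, solves an optimization problem over a loss $L$ and a dataset $D$:

$$\theta = \mathcal{T}(L, \theta_0, D) = \arg\min_{\theta} \mathbb{E}_{x \sim D}\left[ L(x, \theta) \right]$$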

Given the finetuned model parameters $\theta_f$, our aim is to find

$$D^* = \arg\min_{D} \left\| \theta_f - \mathcal{T}(L, \theta_0, D) \right\| = \arg\min_{D} \left\| \theta_f - \arg\min_{\theta} \mathbb{E}_{x \sim D}\left[ L(x, \theta) \right] \right\|$$

We cannot optimize this objective directly. First, evaluating it for any candidate dataset requires training a model on that dataset, which is expensive. Second, computing the loss requires a non-differentiable lookup that converts a token sequence into a sequence of dense embedding vectors, so gradients cannot flow back to the discrete data itself. Typical dataset distillation approaches, which rely on such gradients, are therefore not applicable.
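The first obstacle can be made concrete with a toy sketch (illustrative only, not the authors' setup). It assumes a hypothetical stand-in `train` for the training algorithm: full-batch gradient descent on a linear least-squares model, starting from $\theta_0$. A hidden 3-example subset of an 8-example pool produces $\theta_f$, and the outer objective is then minimized by brute force, which costs one full training run per candidate subset: $\binom{8}{3} = 56$ runs here, a number that grows combinatorially with the pool size.

```python
import itertools
import numpy as np

def train(theta0, X, y, lr=0.1, steps=200):
    # Hypothetical stand-in for the training algorithm: full-batch gradient
    # descent on 0.5 * mean((X @ theta - y) ** 2), starting from theta0.
    theta = theta0.copy()
    for _ in range(steps):
        theta -= lr * X.T @ (X @ theta - y) / len(y)
    return theta

rng = np.random.default_rng(0)
pool_X = rng.normal(size=(8, 3))                 # candidate pool (8 examples)
pool_y = pool_X @ np.array([1.0, -2.0, 0.5])
theta0 = np.zeros(3)
hidden = [0, 3, 5]                               # the "unknown" finetuning set
theta_f = train(theta0, pool_X[hidden], pool_y[hidden])

# Brute-force outer search: one full training run per size-3 candidate subset.
n_runs, best_value, best_subset = 0, np.inf, None
for S in itertools.combinations(range(8), 3):
    idx = list(S)
    value = np.linalg.norm(theta_f - train(theta0, pool_X[idx], pool_y[idx]))
    n_runs += 1
    if value < best_value:
        best_value, best_subset = value, S
print(n_runs, best_value, best_subset)
```

Because the hidden subset is itself in the pool, the best candidate reaches zero distance to $\theta_f$. For a language model, each inner call would be a full finetuning run, which is what makes this exhaustive approach impractical.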

1.1. Method: SELECT

We constrain the problem to data selection instead of data generation: given a large corpus of text, we search for a small set of datapoints that, when used to finetune from $\theta_0$, produce a model close to the finetuned model $\theta_f$.

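We can express this goal as a search for data $x$ whose negative gradient at $\theta_0$, the direction in which a gradient-descent step on $x$ moves the parameters, has the largest projection onto the model diff $\theta_f - \theta_0$:

$$x^* = \arg\max_{x \in D} \left[ -\nabla_{\theta} L(x, \theta_0) \cdot (\theta_f - \theta_0) \right]$$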
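The direction of alignment can be checked with a one-step argument. Assuming plain SGD with learning rate $\eta > 0$, a single update on $x$ from $\theta_0$ moves the parameters by $-\eta \nabla_{\theta} L(x, \theta_0)$, so the projection of that update onto the model diff is $\eta$ times the negative-gradient score:

```latex
\theta_1 = \theta_0 - \eta\, \nabla_{\theta} L(x, \theta_0)
\;\Longrightarrow\;
(\theta_1 - \theta_0) \cdot (\theta_f - \theta_0)
  = \eta \left[ -\nabla_{\theta} L(x, \theta_0) \cdot (\theta_f - \theta_0) \right]
```

For any fixed $\eta > 0$, ranking datapoints by how far one step on them moves the parameters along $\theta_f - \theta_0$ is therefore the same as ranking them by the bracketed term.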

A naive solution is to take the examples whose gradients are most similar to the parameter difference. In practice, however, this yields highly redundant samples, because it ignores batch-level interactions: stochastic gradient descent typically takes each step using gradients summed across multiple examples.
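The redundancy problem is easy to reproduce on a toy pool. The sketch below assumes a linear model with per-example squared loss $\frac{1}{2}(x \cdot \theta - y)^2$, whose gradient has the closed form $(x \cdot \theta - y)\,x$; `per_example_grads` and `naive_top_k` are hypothetical helpers for illustration. The pool contains three identical copies of one example, and independent scoring by $-\nabla_{\theta} L(x, \theta_0) \cdot (\theta_f - \theta_0)$ selects all three.

```python
import numpy as np

def per_example_grads(theta, X, y):
    # Gradient of the per-example loss 0.5 * (x . theta - y) ** 2 w.r.t. theta;
    # row i equals (x_i . theta - y_i) * x_i.
    return (X @ theta - y)[:, None] * X

def naive_top_k(theta0, theta_f, X, y, k):
    # Score each example on its own by -grad . (theta_f - theta0); keep the top k.
    scores = -per_example_grads(theta0, X, y) @ (theta_f - theta0)
    return np.argsort(scores)[::-1][:k], scores

# Toy pool: rows 0-2 are identical copies of the same example.
X = np.array([[3.0, 1.0], [3.0, 1.0], [3.0, 1.0], [0.0, 2.0], [2.0, 0.0]])
y = np.ones(5)
theta0, theta_f = np.zeros(2), np.ones(2)
picked, scores = naive_top_k(theta0, theta_f, X, y, k=3)
batch_direction = -per_example_grads(theta0, X, y)[picked].sum(axis=0)
print(scores, sorted(picked.tolist()), batch_direction)
```

The three copies score 4 while the other two examples score 2, so the selected batch's summed negative gradient is $[9, 3]$: it pushes almost entirely along the first coordinate even though $\theta_f - \theta_0 = [1, 1]$.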

We therefore search instead for the set of points whose summed negative gradient points in the direction of the parameter difference:

$$\arg\max_{B \subseteq D} \left[ \sum_{x \in B} -\nabla_{\theta} L(x, \theta_0) \cdot (\theta_f - \theta_0) \right]$$

Solving for ๐ต exactly requires enumerating all possible subsets of ๐ท and is generally intractable to solve in polynomial time. However, the batch search objective is submodular because it exhibits the diminishing returns property: the marginal gain of adding a new datapoint decreases as the batch grows. The submodularity is known to have an efficient, close-to-optimal greedy solution.

State-of-the-art dataset distillation approaches distill more effectively by matching gradients along the training trajectory, using several model checkpoints $\theta_j$, $j \in [1, P]$. Having only the initial and final models puts us at a significant disadvantage, because an example's gradient at the beginning of training may point in a different direction later in optimization. To make up for the missing intermediate checkpoints, we create synthetic checkpoints by linearly interpolating between the initial and final models:

$$\hat{\theta}_j = \frac{j}{P}\,\theta_0 + \left(1 - \frac{j}{P}\right)\theta_f$$

where ๐‘ƒ is the desired number of synthetic checkpoints. We then search for the batch of examples with a gradient that is most aligned, on average, with the direction of the synthetic checkpoints:

argmax๐ตโІ๐ท[โˆ‘๐‘—=1๐‘ƒโˆ‘๐‘ฅโˆˆ๐ตโˆ‡๐‘ฅ๐ฟ(๐‘ฅ,๐œƒฬ‚๐‘—)โ‹…(๐œƒ๐‘“โˆ’๐œƒฬ‚๐‘—)]

Prior work has shown that the gradient of a language model's last layer can carry enough signal for synthetic data generation. Since our approach requires per-example gradients, which are computationally expensive, we backpropagate only through the last layer, saving memory and reducing overall computation.
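If the last layer is taken to be the output (unembedding) projection $W$ that maps a final hidden state $h$ to vocabulary logits $z = Wh$ (an assumption about which layer is meant, with no bias term), the per-example cross-entropy gradient for one prediction position has the closed form $(\mathrm{softmax}(z) - e_y)\,h^\top$. Only a forward pass to $h$ is then needed; for a full sequence one would sum over positions. The sketch below computes this for a small batch and checks one entry against a central finite difference; `last_layer_grads` and `cross_entropy` are hypothetical helpers.

```python
import numpy as np

def cross_entropy(W, h, y):
    # Loss -log softmax(W h)[y] for one hidden state h and target token y.
    z = W @ h
    z = z - z.max()                               # numerical stability
    return -(z[y] - np.log(np.exp(z).sum()))

def last_layer_grads(H, W, labels):
    # Per-example gradient of cross_entropy w.r.t. W (vocab x d) for hidden
    # states H (n x d): (softmax(W h_i) - onehot(y_i)) outer h_i.
    logits = H @ W.T                              # (n, vocab)
    logits -= logits.max(axis=1, keepdims=True)   # numerical stability
    probs = np.exp(logits)
    probs /= probs.sum(axis=1, keepdims=True)
    probs[np.arange(len(labels)), labels] -= 1.0  # softmax minus one-hot
    return probs[:, :, None] * H[:, None, :]      # (n, vocab, d)

rng = np.random.default_rng(0)
H = rng.normal(size=(2, 3))                       # 2 examples, hidden size 3
W = rng.normal(size=(4, 3))                       # vocabulary of 4 tokens
labels = np.array([1, 3])
G = last_layer_grads(H, W, labels)

# Central finite-difference check of one entry, dL_i / dW[a, b].
i, a, b, eps = 1, 2, 0, 1e-6
W_plus, W_minus = W.copy(), W.copy()
W_plus[a, b] += eps
W_minus[a, b] -= eps
fd = (cross_entropy(W_plus, H[i], labels[i]) - cross_entropy(W_minus, H[i], labels[i])) / (2 * eps)
print(G.shape, fd, G[i, a, b])
```

Storing this tensor still costs $n \times |V| \times d$ values, which for a realistic vocabulary and hidden size is large; this is the storage problem addressed by the projection in the following paragraph.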

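Storing all gradients at their original dimension requires $|D| \cdot |\nabla \ell|$ values, where $|\nabla \ell|$ is the number of entries in one per-example gradient, which quickly becomes prohibitive. To address this, we leverage the classic Johnson-Lindenstrauss lemma, which guarantees that a set of $m$ points in $\mathbb{R}^n$ can be mapped by a random linear projection to a lower-dimensional space $\mathbb{R}^k$ (with $k \ll n$, and $k$ growing only logarithmically in $m$) while approximately preserving pairwise distances, and hence inner products, with high probability.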
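A minimal sketch of the compression step, assuming a dense Gaussian projection matrix $R \in \mathbb{R}^{n \times k}$ with entries drawn from $\mathcal{N}(0, 1/k)$, a standard Johnson-Lindenstrauss construction for which $\mathbb{E}[\langle R^\top u, R^\top v \rangle] = \langle u, v \rangle$. The same $R$ must be applied to the per-example gradients and to the parameter difference so that their inner products remain comparable. `jl_project` is a hypothetical helper, and the random "gradients" are stand-ins; for very large gradient dimensions a dense $R$ would itself be too big to store, so this is only illustrative.

```python
import numpy as np

def jl_project(V, k, seed=0):
    # Map each row of V (m x n) to k dimensions with a Gaussian matrix whose
    # entries have variance 1/k, so inner products are preserved in expectation.
    rng = np.random.default_rng(seed)
    R = rng.normal(scale=1.0 / np.sqrt(k), size=(V.shape[1], k))
    return V @ R

rng = np.random.default_rng(1)
n, k = 4000, 400
grads = rng.normal(size=(5, n))                   # stand-in per-example gradients
delta = grads[0] + 0.5 * rng.normal(size=n)       # parameter difference, correlated with row 0

# Project gradients and delta with the SAME random matrix.
projected = jl_project(np.vstack([grads, delta]), k, seed=2)
norms = np.linalg.norm(grads, axis=1) * np.linalg.norm(delta)
exact = grads @ delta / norms
approx = projected[:-1] @ projected[-1] / norms
print(np.round(exact, 3))
print(np.round(approx, 3))
```

Here storage per gradient drops from 4,000 to 400 values, while the normalized inner products with the parameter difference, which are what the selection scores depend on, stay close to their exact values.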
