arxiv:2403.04652

Yi: Open Foundation Models by 01.AI

Published on Mar 7, 2024
Submitted by akhaliq on Mar 8, 2024
#1 Paper of the day

Abstract

We introduce the Yi model family, a series of language and multimodal models that demonstrate strong multi-dimensional capabilities. The Yi model family is based on 6B and 34B pretrained language models, which we then extend to chat models, 200K long-context models, depth-upscaled models, and vision-language models. Our base models achieve strong performance on a wide range of benchmarks like MMLU, and our finetuned chat models deliver strong human preference rates on major evaluation platforms like AlpacaEval and Chatbot Arena. Building upon our scalable super-computing infrastructure and the classical transformer architecture, we attribute the performance of Yi models primarily to their data quality, resulting from our data-engineering efforts. For pretraining, we construct 3.1 trillion tokens of English and Chinese corpora using a cascaded data deduplication and quality filtering pipeline. For finetuning, we polish a small-scale (fewer than 10K) instruction dataset over multiple iterations such that every single instance has been verified directly by our machine learning engineers. For vision-language, we combine the chat language model with a vision transformer encoder and train the model to align visual representations to the semantic space of the language model. We further extend the context length to 200K through lightweight continual pretraining and demonstrate strong needle-in-a-haystack retrieval performance. We show that extending the depth of the pretrained checkpoint through continual pretraining further improves performance. We believe that, given our current results, continuing to scale up model parameters using thoroughly optimized data will lead to even stronger frontier models.
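The abstract only names the "cascaded data deduplication and quality filtering pipeline". As a rough illustration of what "cascaded" means (each stage only sees the survivors of the previous one, so cheap filters run before expensive ones), here is a toy two-stage sketch: exact-hash deduplication followed by a hypothetical heuristic filter. The helper names and thresholds are assumptions for illustration, not Yi's actual pipeline; real pipelines commonly add near-duplicate (fuzzy) deduplication and model-based quality scoring.

```python
import hashlib
import re


def normalize(text: str) -> str:
    """Lowercase and collapse whitespace so trivially different copies match."""
    return re.sub(r"\s+", " ", text.lower()).strip()


def exact_dedup(docs):
    """Stage 1 (cheap): drop documents whose normalized text was already seen."""
    seen, kept = set(), []
    for doc in docs:
        digest = hashlib.sha256(normalize(doc).encode("utf-8")).hexdigest()
        if digest not in seen:
            seen.add(digest)
            kept.append(doc)
    return kept


def heuristic_filter(docs, min_words=5, max_symbol_ratio=0.3):
    """Stage 2: hypothetical quality rules (minimum length, symbol density)."""
    kept = []
    for doc in docs:
        symbols = sum(1 for ch in doc if not (ch.isalnum() or ch.isspace()))
        if len(doc.split()) >= min_words and symbols / max(len(doc), 1) <= max_symbol_ratio:
            kept.append(doc)
    return kept


def cascade(docs, stages):
    """Run stages in order; each stage only sees the previous stage's survivors."""
    for stage in stages:
        docs = stage(docs)
    return docs


corpus = [
    "The quick brown fox jumps over the lazy dog.",
    "the quick  brown fox jumps over the LAZY dog.",  # case/whitespace variant
    "$$$ ### !!! %%% &&&",  # symbol spam
    "Too short",
]
print(cascade(corpus, [exact_dedup, heuristic_filter]))
```

Ordering matters in a cascade: deduplication runs first here because hashing is cheap and shrinks the input for every later, more expensive stage.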

Community

The sheer amount of data curation devoted to the training pipeline...

There's an interesting discussion of their very impressive long-context capability in this paper; I'm going to speculate on how they achieved it:

Like many others (e.g. https://huggingface.co/papers/2402.10171), they believe that long-context capability is latent in the model and just needs to be teased out with some continued training on longer sequences and specific long-sequence tasks. The big novelty here seems to be how they manage to train on sequence lengths of ~200K, and what they train on to do the extension. The only other public result similar to theirs (with a similar setup, too!) is LWM: https://huggingface.co/papers/2402.08268

Their big 3 things seem to be:

  1. Distributed training for long sequences
  2. Data engineering (overlooked by many others)
  3. Simple RoPE hack (ABF)

on Distributed training

We implement and improve computation-communication overlapping, sequence parallelism, and communication compression to support continual pretraining and finetuning with context lengths of up to 200K. Our method to scale the context length to 200K is based solely on engineering
...
We continue pretraining the full-attention model using sequence parallelism [43] and distributed attention. That is to say, we do not use any sparse or linear attention, but a brute-force implementation of full attention.
...
We use and improve upon the following techniques to tackle the memory and communication restrictions:
(1) ZeRO-1 [60] ...
(2) tensor parallel combined with pipeline parallel [70] within each compute node to avoid inter-node communication bottleneck ...
(3) kernel fusion techniques like FlashAttention [15][14] ...
(4) topology-aware resource allocation (ranking strategy) to minimize the communication across different layers of switches, which is a limitation of a typical fat-tree topology.
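To see why the quoted techniques (sequence parallelism, FlashAttention-style kernel fusion) are needed for brute-force full attention at 200K, here is a back-of-the-envelope sketch of the size of the materialized attention-score matrix alone. It assumes 2-byte (bf16/fp16) elements and counts a single head in a single layer, ignoring weights, other activations, and optimizer state; the 8-way sharding is a hypothetical example.

```python
def score_matrix_bytes(seq_len: int, bytes_per_elem: int = 2) -> int:
    """Bytes for one head's full seq_len x seq_len attention-score matrix."""
    return seq_len * seq_len * bytes_per_elem


def sharded_score_bytes(seq_len: int, n_devices: int, bytes_per_elem: int = 2) -> int:
    """Bytes per device when each device owns a (seq_len / n_devices) x seq_len
    row block of scores, as in ring-style sequence parallelism without online softmax."""
    rows_per_device = seq_len // n_devices
    return rows_per_device * seq_len * bytes_per_elem


for n in (4_096, 32_768, 200_000):
    print(f"{n:>7} tokens: {score_matrix_bytes(n) / 1e9:8.2f} GB per head per layer")

# Hypothetical 8-way sequence parallelism at 200K tokens.
print(f"8-way sharded: {sharded_score_bytes(200_000, 8) / 1e9:.2f} GB per device")
```

At 200K tokens the full score matrix is 80 GB per head per layer (versus about 0.03 GB at 4K), and even an 8-way row split leaves 10 GB per device, which is why avoiding materialization altogether (kernel fusion) matters as much as splitting the sequence.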

So, how are they doing this? It sounds like:

  1. [Engineering] Sequence parallelism à la https://huggingface.co/papers/2105.13120 (confusingly, this is not Ring Attention, even though it is also a ring scheme, called Ring Self-Attention)
    • Speculating: they hook up their nodes in a ring topology. Each node i holds chunks Q_i, K_i, and V_i.
    • Step 1: they calculate the partial attention scores Q_iK_i^T.
    • Step 2: each node sends its key chunk to the next node and receives one from the previous node (so node i now gets K_{(i-1) mod N}, with N total GPUs).
    • Step 3: once every node has computed its full row block Q_iK^T (after N-1 rounds), it gets softmaxed to get similarity scores.
    • Step 4: they start projecting the similarity scores onto V_i (the matching column block of scores times V_i).
    • Step 5: they do the same send/recv scheme for the V chunks as well, accumulating the partial products.
    • Note: this is different from Ring Attention in that it does not do the partial softmax (and the fused MLP) in the same round, hence it needs two passes around the ring for attention (one for QK^T and one for V).
    • Another question I have: why not just scatter + gather the K and V directly? Why do N-1 rounds of point-to-point communication?
  2. [Engineering] Overlapping communication with computation
    • Probably triple-buffering the send/recv/compute so that while you're computing Q_iK_j^T, you're also sending and receiving the next chunks
  3. [Engineering] Communication compression
    • This seems like a neat trick; the paper doesn't go deeper, and I'd love to find out more
  4. [Engineering] FlashAttention 2 (to reduce local memory use per device, and take advantage of kernel fusion)
    • (???) Honestly, I don't see how this can apply to Ring Self-Attention, since RSA must materialize the full Q_iK^T matrix, while FlashAttention 2 is designed to avoid that. Unless this was used during the initial pretraining, while RSA is used during long-context extension?
  5. Otherwise a standard Megatron setup w/ just pipeline + tensor parallelism intra-node (I'm guessing high BW), and sequence-parallelism inter-node (I'm guessing low BW)
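The speculated steps in item 1 can be sanity-checked with a single-process NumPy simulation, where "devices" are list indices and each send/recv is a list rotation. It assumes equal-sized chunks and no causal mask, and it models only the math, not real communication, overlap, or compression. For contrast with the Ring Attention note and the FlashAttention question in item 4, it also includes a single-pass variant that keeps a running (online) softmax, so each device only ever materializes one chunk-by-chunk block of scores at a time. Both are illustrations of the general techniques, not Yi's implementation.

```python
import numpy as np


def reference_attention(q, k, v):
    """Single-device softmax(Q K^T / sqrt(d)) V, used as ground truth."""
    s = q @ k.T / np.sqrt(q.shape[-1])
    p = np.exp(s - s.max(axis=-1, keepdims=True))
    return (p / p.sum(axis=-1, keepdims=True)) @ v


def rotate(chunks):
    """One ring step: device i receives the chunk held by device i-1."""
    n = len(chunks)
    return [chunks[(i - 1) % n] for i in range(n)]


def ring_self_attention(q_chunks, k_chunks, v_chunks):
    """Two-pass ring scheme: pass 1 rotates K to build each device's full
    row block Q_i K^T, then a plain softmax, then pass 2 rotates V."""
    n, d = len(q_chunks), q_chunks[0].shape[-1]
    c = k_chunks[0].shape[0]  # tokens per chunk
    scores = [np.empty((q.shape[0], n * c)) for q in q_chunks]
    k_held = list(k_chunks)
    for r in range(n):
        for i in range(n):
            j = (i - r) % n  # device i currently holds K_j
            scores[i][:, j * c:(j + 1) * c] = q_chunks[i] @ k_held[i].T / np.sqrt(d)
        if r < n - 1:  # N-1 communication rounds in total
            k_held = rotate(k_held)
    probs = []
    for s in scores:  # full rows are local now, so an ordinary softmax works
        e = np.exp(s - s.max(axis=-1, keepdims=True))
        probs.append(e / e.sum(axis=-1, keepdims=True))
    out = [np.zeros((q.shape[0], v_chunks[0].shape[-1])) for q in q_chunks]
    v_held = list(v_chunks)
    for r in range(n):
        for i in range(n):
            j = (i - r) % n  # device i currently holds V_j
            out[i] += probs[i][:, j * c:(j + 1) * c] @ v_held[i]
        if r < n - 1:
            v_held = rotate(v_held)
    return np.concatenate(out, axis=0)


def ring_attention_online(q_chunks, k_chunks, v_chunks):
    """Single-pass variant with an online (running) softmax: K and V travel
    together, and no device holds a full Q_i K^T row block."""
    n, d = len(q_chunks), q_chunks[0].shape[-1]
    m = [np.full((q.shape[0], 1), -np.inf) for q in q_chunks]  # running row max
    denom = [np.zeros((q.shape[0], 1)) for q in q_chunks]  # running softmax denominator
    acc = [np.zeros((q.shape[0], v_chunks[0].shape[-1])) for q in q_chunks]
    kv_held = list(zip(k_chunks, v_chunks))
    for r in range(n):
        for i in range(n):
            k, v = kv_held[i]
            s = q_chunks[i] @ k.T / np.sqrt(d)
            m_new = np.maximum(m[i], s.max(axis=-1, keepdims=True))
            scale = np.exp(m[i] - m_new)  # rescales earlier partial sums
            p = np.exp(s - m_new)
            denom[i] = denom[i] * scale + p.sum(axis=-1, keepdims=True)
            acc[i] = acc[i] * scale + p @ v
            m[i] = m_new
        if r < n - 1:
            kv_held = rotate(kv_held)
    return np.concatenate([a / z for a, z in zip(acc, denom)], axis=0)


rng = np.random.default_rng(0)
n_dev, chunk, d = 4, 8, 16
q, k, v = (rng.standard_normal((n_dev * chunk, d)) for _ in range(3))
ref = reference_attention(q, k, v)
rsa = ring_self_attention(np.split(q, n_dev), np.split(k, n_dev), np.split(v, n_dev))
online = ring_attention_online(np.split(q, n_dev), np.split(k, n_dev), np.split(v, n_dev))
print(np.allclose(rsa, ref), np.allclose(online, ref))  # True True
```

The two-pass version needs the whole row block of scores resident before the softmax, which is the tension with FlashAttention raised in item 4; the online version avoids that by rescaling its running accumulators each round, at the cost of shipping K and V together.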

on Data Engineering for long context

Seems pretty self-explanatory

To adapt the base model to longer context, we continue pretraining the model on 10B tokens from our pretraining data mixture with slightly upsampled long sequences, mostly from books. We observe that only 1-2B tokens are enough for the model to converge to low loss on 4K-200K length, and a lightweight finetuning further induces near-perfect long-context retrieval performance. Based on this observation, we tend to view the capability of modeling dependencies longer than the pretrained length (4K) as an intrinsic capability (rather than one injected by post-training).

This is interesting: 1-2B tokens of continued training on a data mixture containing long-context samples lets the model generalize to longer contexts. However, they also finetune on downstream tasks such as long-context recall/retrieval. That second part seems to be something people often skip.
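The "slightly upsampled long sequences" in the quote above can be illustrated with a toy mixture builder: documents at or above a length threshold are repeated a few times before shuffling, which raises their share of training tokens. The threshold, the repeat factor, and the (doc_id, n_tokens) representation are hypothetical; the excerpt does not give the exact upsampling scheme.

```python
import random


def build_mixture(docs, long_threshold, factor, seed=0):
    """docs: list of (doc_id, n_tokens). Documents with at least long_threshold
    tokens are repeated `factor` times; all others appear once."""
    mixture = []
    for doc_id, n_tokens in docs:
        repeats = factor if n_tokens >= long_threshold else 1
        mixture.extend([(doc_id, n_tokens)] * repeats)
    random.Random(seed).shuffle(mixture)
    return mixture


def long_token_share(mixture, long_threshold):
    """Fraction of all tokens in the mixture that come from long documents."""
    total = sum(n for _, n in mixture)
    return sum(n for _, n in mixture if n >= long_threshold) / total


# Hypothetical corpus: 950 short web docs (1K tokens each) and one 50K-token book.
docs = [(f"web{i}", 1_000) for i in range(950)] + [("book0", 50_000)]
print(long_token_share(build_mixture(docs, 4_096, factor=1), 4_096))  # 0.05
print(long_token_share(build_mixture(docs, 4_096, factor=2), 4_096))  # ~0.095
```

Doubling the long documents moves them from 5% to roughly 9.5% of tokens, a "slight" upsampling in the sense of the quote, while most of the mix stays the original pretraining distribution.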

We continue pretraining the Yi 6B/34B base models on a data mixture of:
(1) original pretraining data, as introduced in section 2;
(2) length-upsampled long-context data, where the long documents are mostly from books;
(3) multi-document question-answering synthetic data, where we construct QA pairs in which the answer contains a recitation of the related paragraph before the answer.
Our data approach mostly follows the data engineering practice in Fu et al. [22] and Yu et al. [87]. We continue pretraining the model on 5B tokens with a 4M batch size, which translates to 100 optimization steps. Aligning with the concurrent work from Fu et al. [22], we observe that such lightweight continual pretraining is already able to enable strong performance on the Needle-in-a-Haystack test, as we will show in Figure 6.

The numbers on the # of training tokens are somewhat conflicting (10B vs. 5B, with 1-2B said to be enough to converge), but you can see their data mixture and training schedule. It's neat that they treat making an LLM generalize to long context as a separate task from making it do well on long-context downstream tasks; +100, that's a critical distinction.
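One way to make the token-count conflict concrete: assuming "4M batch size" means roughly 4M tokens per optimization step (the excerpt does not say tokens explicitly), the quoted figures do not multiply out.

```python
def optimization_steps(total_tokens: float, tokens_per_step: float) -> float:
    """Number of optimizer steps needed to consume total_tokens."""
    return total_tokens / tokens_per_step


print(optimization_steps(5e9, 4e6))  # 1250.0 steps for 5B tokens at 4M tokens/step
print(100 * 4e6)  # 400000000.0 tokens consumed by 100 steps at 4M tokens/step
print(5e9 / 100)  # 50000000.0 tokens/step implied by "5B tokens in 100 steps"
```

Under this reading, either the step count, the token count, or the meaning of "batch size" differs from what the sentence suggests, and the 10B figure in the earlier excerpt adds yet another number.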

on Architectural changes

A very minor one, which LWM and CodeLLaMA also do:

We adjust the base frequency (RoPE ABF), introduced in Xiong et al. [82], to support long context windows up to 200K where the base model itself is trained on 4K context length.

According to https://huggingface.co/01-ai/Yi-6B-200K/blob/main/config.json, a base frequency of 10000000.0 (i.e. 10^7) seems to be used.
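A common intuition for why raising the RoPE base helps: with the standard RoPE frequencies θ_k = base^(-2k/d), the slowest-rotating dimension pair completes one full turn after 2π / θ_min tokens. The sketch below compares that wavelength for base 10,000 (the original RoPE default) and 10,000,000. It assumes a head dimension of 128, which is typical for Llama-style models but is an assumption here, not a value taken from the linked config.

```python
import math


def rope_inv_freqs(head_dim: int, base: float) -> list[float]:
    """Standard RoPE rotation frequencies theta_k = base^(-2k/d), k = 0..d/2-1."""
    return [base ** (-2 * k / head_dim) for k in range(head_dim // 2)]


def longest_wavelength(head_dim: int, base: float) -> float:
    """Tokens for the slowest-rotating dimension pair to complete one full turn."""
    return 2 * math.pi / min(rope_inv_freqs(head_dim, base))


HEAD_DIM = 128  # assumption: typical Llama-style head size
for base in (10_000.0, 10_000_000.0):
    print(f"base={base:,.0f}: ~{longest_wavelength(HEAD_DIM, base):,.0f} tokens")
```

With base 10,000 the slowest pair wraps around after roughly 54K tokens, well short of 200K; with 10^7 it stretches to roughly 49M tokens, so positions across a 200K window stay distinguishable in that dimension. This is a heuristic picture only, not a full account of why ABF works.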

Models citing this paper 213


Datasets citing this paper 0


Spaces citing this paper 204

Collections including this paper 23