Overview

I made PicoGPT as a way to run small experiments on GPT-2-small-sized architectures. This comes primarily from experience: even top labs run a lot of single-node (8xH100) experiments before scaling. It is loosely based on the NanoGPT speedrun repository, but I removed some features I suspect would not scale well, since the purpose of PicoGPT is not speedrunning but running experiments that inform theories about how to make better models at larger scales. This is not to say that certain architectural changes NanoGPT made won't scale, but rather that I didn't have evidence that they do, so I kept the experimental repository simple. Also, speedrunning severely limits the design space of experiments (for example, MoEs are out of the question for speedrunning due to communication overhead: two gathers and a scatter).
The initial PicoGPT experiments include:

Multi-latent Attention
-reduces parameter count, kv joint embedding, shared RoPE across heads, partial NoPE
Flex Attention
-allows treating multiple documents as a single sequence via document masking (no padding necessary), a sliding window over the context length, and causal masking
ReLU**2
-more sample efficient than GLU variants, even for larger models? (see the sketch below)
RoPE
feed-forward MoEs with expert parallelism (experiment 2)
-increases expressivity of the MLPs with limited inference cost

I'll probably do many more experiments on this repo; the design space is huge. I often think about SSMs, hybrid models, and parameter sharing, all of which readily fit into model experiments at this size.
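As a quick illustration of the ReLU**2 feed-forward block mentioned in the list above, here is a minimal sketch; the 4x expansion factor is my own assumption, not necessarily what the repo uses.

#ReLU**2 feed-forward sketch
import torch
import torch.nn as nn

class ReluSquaredMLP(nn.Module):
    #gate-free MLP whose activation is relu(x)**2, used in place of a GLU variant
    def __init__(self, hidden_dim, expansion=4):
        super().__init__()
        self.up = nn.Linear(hidden_dim, hidden_dim * expansion)
        self.down = nn.Linear(hidden_dim * expansion, hidden_dim)

    def forward(self, x):
        return self.down(torch.relu(self.up(x)) ** 2)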
Things I did not include from NanoGPT:

tanh logit soft capping (gemma)
long short sliding window attention (gemma)
QK norm
encoder/decoder u-net architecture (nanoGPT)
Muon optimizer
Window warmup (NanoGPT found that slowly increasing the context length to its maximum during training sped up training. Intuitively I buy this, but I didn't include it; a sketch of such a schedule follows this list)
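For reference, window warmup would look something like the sketch below: a schedule that grows the sliding-window size toward the full context over the first chunk of training and gets plugged into the attention mask. The linear shape and the min/max values are my own assumptions, not NanoGPT's exact recipe.

#window warmup schedule sketch (not used in PicoGPT)
def window_size_at_step(step, warmup_steps, min_window=128, max_window=1024):
    #linearly grow the attention window from min_window to max_window over warmup_steps
    if step >= warmup_steps:
        return max_window
    size = int(min_window + (step / warmup_steps) * (max_window - min_window))
    #snap to a multiple of 128 so it stays friendly to block masks
    return max(min_window, (size // 128) * 128)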
The initial architecture is slightly more sample efficient than GPT-2 on 5B tokens of FineWeb (I did not do a full FineWeb 10B-token run). It has a lot of room for growth, of course.

Initial Hypothesis

My initial idea was as follows: multi-latent attention is really interesting because of the parameter savings from low-rank projections, as well as the inference savings of caching a low-rank RoPE key tensor and a low-rank joint kv embedding. I wanted to run some experiments to loosely benchmark the capabilities of this architectural change and see if I could improve upon it at a smaller scale.
To my knowledge, pretty much every other company was using context parallelism rather than being able to run attention with data parallelism alone. DeepSeek was already paying the cost of running many MoEs with expert parallelism. Given the low-rank nature of the key and value projections, it seemed potentially worthwhile to add MoEs there for extra expressivity.
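To make the caching argument concrete, here is some rough napkin math at DeepSeek scale, using the dimensions discussed in the implementation section below (128 heads of dim 128, a 512-dim joint kv latent, a 64-dim shared RoPE key). The comparison against caching full per-head keys and values is my own framing.

#per-token, per-layer KV cache in elements (ignoring dtype)
n_heads, d_head = 128, 128
d_kv_latent, d_rope = 512, 64

full_kv_cache = 2 * n_heads * d_head      #cache K and V for every head
mla_cache = d_kv_latent + d_rope          #cache only the joint kv latent + the shared RoPE key

print(full_kv_cache, mla_cache)           #32768 vs 576 elements per token per layer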

Background/Relevant Papers

There is the obvious work in DeepSeek-V2 and DeepSeek-V3, which you should read immediately. Those papers describe MLA; it's really not a complicated architectural intervention, but it works. It's probably also important to read other work on cutting attention cost, such as Grouped Query Attention (used in Gemma and Qwen). In GQA, key and value heads are shared across multiple query heads, which similarly reduces KV parameters and cache. GQA is not as performant as MLA, which suggests some interesting theories about how attention works.
I initially got the idea to add MoEs to the attention block from reading the MoEUT paper, as well as the Switch-Head paper, since Switch-Head got reasonably good results from adding MoEs to the value and output projections. That paper also found that messing with the inner product by adding MoEs to the query and key projections is problematic.

Multi Latent Attention Implementation

Of course, I had to dramatically scale down the MLA implementation from DeepSeek-V2/V3. There is a large design space here as well, and I highly doubt I found the optimal configuration given my compute budget. MLA (see the code example at the bottom of the page, or the GitHub repo): DeepSeek's implementation does several projections from the initial sequence of shape (B,S,7168). The query tensor is first projected down to a low-rank latent of 1536 (Cq) before being up-projected to 128*128 (num_heads*head_dim). This part is completely without positional encodings (NoPE). Obviously they still need positional encodings, but they realized they can share a small, extremely low-rank RoPE tensor across all heads. They project from the original sequence (B,S,7168)->(B,S,qr=64), then apply RoPE. They repeat this across all heads and concatenate, so the head dimension of the inner product is (head_dim+qr).
For keys, the same RoPE technique is used (with the same dims as the RoPE queries, so the inner product works out). For keys and values they do the same kind of initial down projection from the original sequence, but they don't just share parameters, they share the activations as a joint embedding: x->kv_low_rank. This allows them to cache this single low-rank tensor for both keys and values. The initial down projection is even lower rank than the queries', which saves parameters for training and cache for inference: x->kv_low_rank is (B,S,512). They can then up-project from this 512-dim super-low-rank joint embedding space for the NoPE keys and the value projections. If you do the napkin math, it's something like 100M parameters saved per block (sketched below).
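Here is that napkin math written out with the dimensions above. The MHA baseline (num_heads*head_dim equal to the 7168 model dim) and the omission of biases and output projections are simplifying assumptions on my part.

#rough attention parameter count per block (weights only)
d_model = 7168
d_q_latent, d_kv_latent, d_rope = 1536, 512, 64
n_heads, d_head = 128, 128
d_heads_total = n_heads * d_head                              #16384

mha = 3 * d_model * d_model                                   #vanilla Q, K, V projections

mla = (d_model * d_q_latent + d_q_latent * d_heads_total      #query down + up
       + 2 * d_model * d_rope                                 #shared RoPE q and k (as described above)
       + d_model * d_kv_latent                                #joint kv down projection
       + 2 * d_kv_latent * d_heads_total)                     #key and value up projections

print(f"MHA ~{mha/1e6:.0f}M vs MLA ~{mla/1e6:.0f}M per block")  #roughly 154M vs 58M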

In my implementation I had to use a dramatically different scale. Because of compute, there are many combinations I couldn't test, but I considered this a second-order problem; I was happy to get some training runs in regardless. I use the same low-rank dimension for both queries and keys, down-projecting from 768 (GPT-2) to 384, then up-projecting back to the original 768 to be distributed amongst 12 heads. I kept DeepSeek's shared RoPE dim of 64, which makes half of my attention inner product shared across heads. The only piece that isn't scaled down directly from DeepSeek's implementation is the value up-projection, which had to be double the original embedding dimension to work well with flex attention (so the value head dim matches the 128-dim query/key head dim). You could easily get around this by writing a custom kernel, but that would probably overcomplicate this project, and I didn't mind the extra expressivity after attention is computed anyways.

Experiments

The first experiment was a baseline GPT-2-sized model with MLA replacing all MHA blocks.
Config:
768 dim embed, 12 layers, 12 heads
384 dim query and key low_rank
1536-dim value up-projection (2x the embedding dim)
1024 sequence length (sliding window and document masking instead of padding for speed)
5B tokens of webtext
This resulted in a lower val_loss than naive GPT-2 at 5B tokens trained; GPT-2 only reaches about 3.4 val_loss at 5B tokens. This is not conclusive, but getting better sample efficiency while reducing parameters suggests it is a reasonable direction.

Now, with not very much compute budget remaining, I decided I had a decent enough baseline to test my MoE ideas. It seemed somewhat obvious: low-rank projections reduce expressivity by bottlenecking information, so if I add a load-balanced MoE to the value and output projections it's a free lunch. This ended up not being the case. You can try it yourself, but even with load balancing the model is unable to learn effectively. It saturates quite early, within 2B tokens.
This could obviously be because of the sample-efficiency problems of training MoEs, which makes my 5B-token test an unfair comparison, but saturating early even with load balancing points at something more fundamental. I think perhaps attention simply does not need more expressivity. The Switch-Head paper found that adding MoEs prior to the inner product is detrimental, and the value and output projections are just linear maps. How much expressivity is there even to gain from a linear MoE?
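For concreteness, the kind of layer I'm describing looks roughly like the sketch below: a token-choice routed mixture of linear experts with a Switch-style load-balancing auxiliary loss, dropped in where the value and output projections sit. This is a simplified single-device loop, not the expert-parallel code in the repo, and the expert count and top-k are assumptions.

#load-balanced MoE linear sketch (single device, no expert parallelism)
import torch
import torch.nn as nn
import torch.nn.functional as F

class MoELinear(nn.Module):
    #token-choice routed mixture of linear experts with a load-balancing aux loss
    def __init__(self, in_dim, out_dim, num_experts=8, top_k=2):
        super().__init__()
        self.router = nn.Linear(in_dim, num_experts)
        self.experts = nn.ModuleList(nn.Linear(in_dim, out_dim) for _ in range(num_experts))
        self.num_experts = num_experts
        self.top_k = top_k

    def forward(self, x):
        B, N, D = x.shape
        flat = x.reshape(-1, D)                                      #(B*N, in_dim)
        probs = F.softmax(self.router(flat), dim=-1)                 #(B*N, num_experts)
        top_p, top_i = probs.topk(self.top_k, dim=-1)                #(B*N, top_k)

        out = flat.new_zeros(flat.size(0), self.experts[0].out_features)
        for e, expert in enumerate(self.experts):
            token_idx, slot = (top_i == e).nonzero(as_tuple=True)    #tokens routed to expert e
            if token_idx.numel() == 0:
                continue
            out[token_idx] += top_p[token_idx, slot].unsqueeze(-1) * expert(flat[token_idx])

        #Switch-style aux loss: fraction of tokens whose top-1 pick is each expert,
        #times the mean router probability for that expert
        frac = F.one_hot(top_i[:, 0], self.num_experts).float().mean(0)
        self.aux_loss = self.num_experts * (frac * probs.mean(0)).sum()
        return out.reshape(B, N, -1)

During training the aux_loss values would be summed across layers and added to the language modeling loss with a small coefficient; swapping the linear expert for an MLP expert gives roughly the shape of the feed-forward MoE experiment described next.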

If our only option is to modify the output of attention, and the MLPs come right after, we might as well just add the MoEs there and gain back the expressivity. This is exactly what DeepSeek does. I still ran into sample-efficiency issues training an MoE on 5B tokens of text, but there was no saturation, and it looked promising to beat my MLA model after ~10B tokens. However, given my GPU-poor status, I was not going to wait that long. If you have an H100 node I implore you to try it out! The code is available as its own model + train file. It should do about 165,000 tok/s with expert parallelism.

Discussion

Overall, I seem to have reverse engineered why Switch-Head has seen relatively little adoption, and why DeepSeek did not use MoEs everywhere they could. I don't think this is as interesting as the implication of MLA: the idea that low-rank linear projections are sufficient to turn the input sequence into queries, keys, and values for attention is extremely interesting.
Perhaps the lower-dimensional hyperplane is not as important as splitting heads, which is likely why DeepSeek decided to up-project to an even higher embedding dim; it's like giving fine-scale resolution back to the low-rank hyperplane. It's honestly amazing that it works even at GPT-2 scale.
If I had more compute I would have run on more training tokens + evals, but I sadly ran out of budget. If you are willing to fund experiments of this size, please contact me.

#MultiLatentAttention
#note: RotaryPositionEmbedding is defined elsewhere in the repo
import torch
import torch.distributed
import torch.nn as nn
from torch.nn.attention.flex_attention import flex_attention, create_block_mask
class MultiLatentAttention(nn.Module):
    def __init__(self,hidden_dim,num_heads=12,low_rank=2,block_size=128,max_seq_len=1024):
        super().__init__()
        self.num_heads=num_heads
        self.head_dim=hidden_dim//num_heads
        self.block_size=block_size
        self.max_seq_len=max_seq_len
        assert hidden_dim % num_heads == 0, "hidden_dim must be divisible by num_heads"
        #downproj for q
        self.qd_proj=nn.Linear(hidden_dim,hidden_dim//low_rank) 
        self.qu_proj=nn.Linear(hidden_dim//low_rank,hidden_dim)

        #shared RoPE projection for q (one head_dim, expanded across all heads)
        self.qr_proj=nn.Linear(hidden_dim,self.head_dim)
        #shared downproj for k,v
        self.kvd=nn.Linear(hidden_dim,hidden_dim//low_rank)
        self.v_up_proj=nn.Linear(hidden_dim//low_rank,hidden_dim*2)
        self.k_up_proj=nn.Linear(hidden_dim//low_rank,hidden_dim)

        #shared RoPE projection for k (one head_dim, expanded across all heads)
        self.kr_proj=nn.Linear(hidden_dim,self.head_dim)
        #output proj
        self.o_proj = nn.Linear(hidden_dim*2, hidden_dim)

        self.rope=RotaryPositionEmbedding(self.head_dim)
        self.scale = (2*self.head_dim) ** -0.5 #per-head dim after concat is head_dim (NoPE) + head_dim (RoPE)

        #requires an initialized torch.distributed process group (unused in this forward pass)
        self.world_size=torch.distributed.get_world_size()

    def forward(self, x,token_seq):
        #x is expected to already be layer-normed before this module
        B, N,dim = x.shape
        assert B == 1, "Must use batch size = 1 for FlexAttention"


        # query projections
        qd=self.qd_proj(x) #B,N,low_rank_dim
        qr=self.qr_proj(x).unsqueeze(2) #B,N,1,head_dim (shared RoPE component)
        qr=qr.expand(-1,-1,self.num_heads,-1).permute(0,2,1,3) #B,num_heads,N,head_dim (same RoPE tensor for every head)
        qr=self.rope(qr)
        q=self.qu_proj(qd) #B,N,dim
        q=q.reshape(B,N,self.num_heads,self.head_dim).permute(0,2,1,3)
        q=torch.cat((q,qr),dim=-1) #B,num_heads,N,2*head_dim (NoPE + shared RoPE)


        #keys
        low_rank_kv=self.kvd(x) #B,S,compressed_dim
        k=self.k_up_proj(low_rank_kv)
        kr=self.kr_proj(x).unsqueeze(2)
        kr=kr.expand(-1,-1,self.num_heads,-1).permute(0,2,1,3)
        kr=self.rope(kr)
        k= k.reshape(B,N,self.num_heads,self.head_dim).permute(0,2,1,3)
        k=torch.cat((k,kr),dim=-1) #B,num_heads,N,2*head_dim (NoPE + shared RoPE)

        #values

        v=self.v_up_proj(low_rank_kv) 
        v=v.reshape(B,N,self.num_heads,(self.head_dim*2)).permute(0,2,1,3)


        docs = (token_seq == 50256).cumsum(0) #50256 = GPT-2 <|endoftext|>; cumsum assigns a document id to each position
        def document_causal_mask(b, h, q_idx, kv_idx):
            causal_mask = q_idx >= kv_idx
            document_mask = docs[q_idx] == docs[kv_idx]
            window_mask = q_idx - kv_idx < 1024 #sliding window of 1024 tokens
            return causal_mask & document_mask & window_mask

        S = len(token_seq)
        block_mask = create_block_mask(document_causal_mask, None, None, S, S, device="cuda", _compile=True)

        x = flex_attention(q, k, v, block_mask=block_mask, scale=self.scale)
        x = x.transpose(1, 2).reshape(B, N, -1)

        x=self.o_proj(x)
        return x
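
For completeness, a minimal smoke test of the module as written might look like the following. It needs a CUDA device (create_block_mask is called with device="cuda") and a single-process gloo process group, only because the constructor calls torch.distributed.get_world_size(); it also assumes RotaryPositionEmbedding from the repo is available in the same file.

#smoke test (run in the same file as the class above)
import os
import torch
import torch.distributed as dist

if __name__ == "__main__":
    os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
    os.environ.setdefault("MASTER_PORT", "29500")
    dist.init_process_group("gloo", rank=0, world_size=1)  #single process, just to satisfy get_world_size()

    hidden_dim, seq_len = 768, 1024
    attn = MultiLatentAttention(hidden_dim).cuda()

    token_seq = torch.randint(0, 50257, (seq_len,), device="cuda")  #fake GPT-2 token ids
    x = torch.randn(1, seq_len, hidden_dim, device="cuda")          #batch size must be 1 for flex attention

    out = attn(x, token_seq)
    print(out.shape)  #torch.Size([1, 1024, 768])

    dist.destroy_process_group()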