I just finished watching Andrej Karpathy’s video Let’s reproduce GPT-2 (124M). It is fascinating how a complex LLM system can be broken down into such a simplified version that anyone with basic machine learning knowledge can understand it. I’m not an LLM researcher, but breaking down an open-source LLM doesn’t seem difficult, since plenty of public code and blog posts are available. In this blog, I will show a vanilla Llama 3 implementation that loads the pre-trained weights into the network and checks that its output matches the Hugging Face implementation. I’m doing all this just for fun, so if there are any mistakes, don’t hesitate to let me know.

Hugging Face Llama 3

We’ll start with the Hugging Face Llama 3. The goals are:

  1. To know the exact network layers of the Llama 3 pre-trained model, so we can build a compatible model.
  2. To get the generated answer from the original Llama 3 model, so we can compare our results with it and verify our implementation.

First, we print the names of each Llama 3 layer.
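
A minimal sketch of this step (using the same meta-llama/Meta-Llama-3-8B-Instruct checkpoint as the rest of this post) is to load the model with transformers and print its named parameters:

from transformers import AutoModelForCausalLM

model_hf = AutoModelForCausalLM.from_pretrained("meta-llama/Meta-Llama-3-8B-Instruct")
# print every parameter name and its shape, e.g. model.layers.0.self_attn.q_proj.weight torch.Size([4096, 4096])
for name, param in model_hf.named_parameters():
    print(name, param.shape)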

Then, we get an example output from the Hugging Face Llama 3.

import transformers
import torch
from transformers import set_seed

pipeline = transformers.pipeline(
  "text-generation",
  model="meta-llama/Meta-Llama-3-8B-Instruct",
  device="cuda",
  use_cache=False,  # disable the KV cache (the vanilla implementation below does not use one)
)

set_seed(42)
# greedy decoding (do_sample=False), so the generated text is deterministic
output = pipeline("Hello, I'm a language model,", max_length=100, num_return_sequences=1, truncation=True, do_sample=False)
print(output)

# --------------------------------------------------------------------------------------------------- #
# output:
[{'generated_text': "Hello, I'm a language model, and I'm here to help you with your questions. I can provide information on a wide range of topics, from science and history to entertainment and culture. I can also help you with language-related tasks, such as grammar and vocabulary practice, and even assist with writing and proofreading. So, what's on your mind? What do you want to talk about or ask me? I'm all ears!"}]

My vanilla Llama 3

Here is my vanilla Llama 3. Some of the code is adapted from or directly borrowed from the Hugging Face Llama 3 implementation and the Meta Llama 3 GitHub repository.

from dataclasses import dataclass

import torch
from torch import nn
from torch.nn import functional as F

from tokenizer import Tokenizer


@dataclass
class ModelArgs:
    dim: int = 4096
    n_kv_heads: int = 8
    vocab_size: int = 128256  # 128,000 BPE tokens + 256 reserved special tokens
    n_layers: int = 32
    n_heads: int = 32
    ffn_dim_multiplier: float = 1.3
    multiple_of: int = 256  # make SwiGLU hidden layer size multiple of large power of 2
    norm_eps: float = 1e-5
    rope_theta: float = 500000

    max_seq_len: int = 2048


class LlamaRotaryEmbedding(nn.Module):
    def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, scaling_factor=1.0):
        super().__init__()
        self.scaling_factor = scaling_factor
        self.dim = dim
        self.max_position_embeddings = max_position_embeddings
        self.base = base
        inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float().to(device) / self.dim))
        self.register_buffer("inv_freq", inv_freq, persistent=False)
        # kept from the Hugging Face implementation for compatibility; no cos/sin cache is registered here
        self.max_seq_len_cached = max_position_embeddings

    @torch.no_grad()
    def forward(self, x, position_ids):
        # x: [bs, num_attention_heads, seq_len, head_size]
        inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
        position_ids_expanded = position_ids[:, None, :].float()
        # Force float32 since bfloat16 loses precision on long contexts
        # See https://github.com/huggingface/transformers/pull/29285
        device_type = x.device.type
        device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu"
        with torch.autocast(device_type=device_type, enabled=False):
            freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
            emb = torch.cat((freqs, freqs), dim=-1)
            cos = emb.cos()
            sin = emb.sin()
        return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)


def rotate_half(x):
    """Rotates half the hidden dims of the input."""
    x1 = x[..., : x.shape[-1] // 2]
    x2 = x[..., x.shape[-1] // 2 :]
    return torch.cat((-x2, x1), dim=-1)


def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
    cos = cos.unsqueeze(unsqueeze_dim)
    sin = sin.unsqueeze(unsqueeze_dim)
    q_embed = (q * cos) + (rotate_half(q) * sin)
    k_embed = (k * cos) + (rotate_half(k) * sin)
    return q_embed, k_embed


def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    batch, num_key_value_heads, slen, head_dim = hidden_states.shape
    if n_rep == 1:
        return hidden_states
    hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
    return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)


class RMSNorm(nn.Module):
    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def _norm(self, x):
        return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)

    def forward(self, x):
        output = self._norm(x.float()).type_as(x)
        return output * self.weight


class Attention(nn.Module):
    def __init__(self, model_args: ModelArgs) -> None:
        super().__init__()
        self.dim, self.n_heads = model_args.dim, model_args.n_heads
        self.head_dim = model_args.dim // model_args.n_heads
        self.n_kv_heads = model_args.n_kv_heads
        self.n_rep = self.n_heads // self.n_kv_heads
        self.q_proj = nn.Linear(self.dim, self.n_heads * self.head_dim, bias=False)
        self.k_proj = nn.Linear(self.dim, self.n_kv_heads * self.head_dim, bias=False)
        self.v_proj = nn.Linear(self.dim, self.n_kv_heads * self.head_dim, bias=False)
        self.o_proj = nn.Linear(self.n_heads * self.head_dim, self.dim, bias=False)

        self.rotary_emb = LlamaRotaryEmbedding(self.head_dim, max_position_embeddings=model_args.max_seq_len, base=model_args.rope_theta)
    
    def forward(self, x, pos_ids):
        bs, seqlen, _ = x.shape
        xq, xk, xv = self.q_proj(x), self.k_proj(x), self.v_proj(x)

        xq = xq.view(bs, seqlen, self.n_heads, self.head_dim).transpose(1, 2)
        xk = xk.view(bs, seqlen, self.n_kv_heads, self.head_dim).transpose(1, 2)
        xv = xv.view(bs, seqlen, self.n_kv_heads, self.head_dim).transpose(1, 2)

        cos, sin = self.rotary_emb(xv, pos_ids)
        xq, xk = apply_rotary_pos_emb(xq, xk, cos, sin)

        # repeat k/v heads if n_kv_heads < n_heads (grouped-query attention)
        xk = repeat_kv(xk, self.n_rep)  # (bs, n_heads, seqlen, head_dim)
        xv = repeat_kv(xv, self.n_rep)  # (bs, n_heads, seqlen, head_dim)

        # causal mask; the full sequence is re-attended on every call since there is no KV cache
        output = F.scaled_dot_product_attention(xq, xk, xv, is_causal=True)
        output = output.transpose(1, 2).contiguous()  # (bs, seqlen, n_heads, head_dim)
        output = output.view(bs, seqlen, -1)
        return self.o_proj(output)


class MLP(nn.Module):
    def __init__(self, model_args: ModelArgs) -> None:
        super().__init__()
        # SwiGLU hidden size: start from 2/3 * (4 * dim), scale by ffn_dim_multiplier,
        # then round up to a multiple of `multiple_of` (14336 for this 8B config)
        hidden_dim = int(2 * model_args.dim * 4 / 3)
        hidden_dim = int(model_args.ffn_dim_multiplier * hidden_dim)
        hidden_dim = model_args.multiple_of * ((hidden_dim + model_args.multiple_of - 1) // model_args.multiple_of)
        self.gate_proj = nn.Linear(model_args.dim, hidden_dim, bias=False)
        self.up_proj = nn.Linear(model_args.dim, hidden_dim, bias=False)
        self.down_proj = nn.Linear(hidden_dim, model_args.dim, bias=False)

    def forward(self, x):
        return self.down_proj(F.silu(self.gate_proj(x)) * self.up_proj(x))


class TransformerBlock(nn.Module):
    def __init__(self, layer_id: int, model_args: ModelArgs) -> None:
        super().__init__()
        self.self_attn = Attention(model_args)
        self.mlp = MLP(model_args)
        self.input_layernorm = RMSNorm(dim=model_args.dim, eps=model_args.norm_eps)
        self.post_attention_layernorm = RMSNorm(dim=model_args.dim, eps=model_args.norm_eps)
    
    def forward(self, x, pos_ids):
        h = x + self.self_attn(self.input_layernorm(x), pos_ids)
        out = h + self.mlp(self.post_attention_layernorm(h))
        return out


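# Note: the attribute names below (embed_tokens, layers.<i>.self_attn.q_proj, mlp.gate_proj,
# input_layernorm, post_attention_layernorm, norm, lm_head, ...) deliberately mirror the
# Hugging Face Llama module names, so the state-dict keys line up in from_pretrained()
# once the "model." prefix is stripped.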
class GPT(nn.Module):
    def __init__(self, model_args: ModelArgs) -> None:
        super().__init__()
        self.embed_tokens = nn.Embedding(model_args.vocab_size, model_args.dim)

        self.layers = nn.ModuleDict()
        for layer_id in range(model_args.n_layers):
            self.layers[str(layer_id)] = TransformerBlock(layer_id, model_args)
        
        self.norm = RMSNorm(dim=model_args.dim, eps=model_args.norm_eps)
        self.lm_head = nn.Linear(model_args.dim, model_args.vocab_size, bias=False)

    def forward(self, x):
        bs, seqlen = x.shape
        # position ids 0..seqlen-1, shared across the batch (consumed by the rotary embedding)
        pos_ids = torch.arange(seqlen, device=x.device).unsqueeze(0).expand(bs, -1)

        h = self.embed_tokens(x)

        for layer in self.layers.values():
            h = layer(h, pos_ids)
        
        h = self.norm(h)
        output = self.lm_head(h)
        return output

    @classmethod
    def from_pretrained(cls, model_type):
        config = ModelArgs()
        model = GPT(config)
        sd = model.state_dict()
        sd_keys = sd.keys()

        # init a huggingface/transformers model
        from transformers import AutoModelForCausalLM
        model_hf = AutoModelForCausalLM.from_pretrained("meta-llama/Meta-Llama-3-8B-Instruct")
        sd_hf = model_hf.state_dict()

        # copy while ensuring all of the parameters are aligned and match in names and shapes
        sd_keys_hf = sd_hf.keys()
        assert len(sd_keys_hf) == len(sd_keys), f"mismatched keys: {len(sd_keys_hf)} != {len(sd_keys)}"
        for k in sd_keys_hf:
            # plain copy; our parameter names match the HF names once the "model." prefix is stripped
            assert sd_hf[k].shape == sd[k.replace('model.', '')].shape
            with torch.no_grad():
                sd[k.replace('model.', '')].copy_(sd_hf[k])

        return model


# --------------------------------------------------------------------------------------------------- #
num_return_sequences = 1
max_length = 100

# model = GPT(ModelArgs())
model = GPT.from_pretrained("llama3")

# print model layers
sd = model.state_dict()
for k, v in sd.items():
    print(k, v.shape)

model.eval()
model.cuda()

# prefix tokens, encoded with Meta's Llama 3 tokenizer (tokenizer.model from the official repo)
enc = Tokenizer(model_path="llama3/tokenizer.model")
tokens = enc.encode("Hello, I'm a language model,", bos=False, eos=False)

tokens = torch.tensor(tokens, dtype=torch.long)
tokens = tokens.unsqueeze(0).repeat(num_return_sequences, 1)
x = tokens.to('cuda')

torch.manual_seed(42)
torch.cuda.manual_seed(42)
while x.size(1) < max_length:
    with torch.no_grad():
        logits = model(x)  # (bs, seqlen, vocab_size); the whole prefix is re-run each step (no KV cache)
        logits = logits[:, -1, :]  # logits at the last position
        probs = F.softmax(logits, dim=-1)
        # top-k with k=1 makes the multinomial draw deterministic, i.e. greedy decoding,
        # matching do_sample=False in the Hugging Face pipeline above
        topk_probs, topk_indices = torch.topk(probs, 1, dim=-1)
        ix = torch.multinomial(topk_probs, num_samples=1)
        xcol = torch.gather(topk_indices, -1, ix)
        x = torch.cat((x, xcol), dim=1)  # append the new token and continue

for i in range(num_return_sequences):
    tokens = x[i, :max_length].tolist()
    try:
        # 128009 is Llama 3's <|eot_id|> (end-of-turn) token; cut the output off there
        index = tokens.index(128009)
        tokens = tokens[:index]
    except ValueError:
        # <|eot_id|> was never generated, so keep the full sequence
        print("Token 128009 is not in the list. No changes made.")
    decoded = enc.decode(tokens)
    print(">", decoded)

Results: