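LLaMA 3 is one of the most promising open-source models after Mistral and can handle a wide range of tasks. Now that LLaMA 3 has been released, we will recreate it in a simpler way.
We won't use a GPU in this blog, but you will need at least 17 GB of RAM because we will load files larger than 15 GB. If that is a problem, you can use Kaggle instead: since we don't need a GPU, Kaggle offers 30 GB of RAM when only CPU cores are used as the accelerator.
The good news is that we won't use object-oriented programming (OOP), only plain Python. You should, however, have a basic understanding of neural networks and the Transformer architecture. These are the only two prerequisites for following this blog.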
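Difference between LLaMA 2 and LLaMA 3
Before looking at the technical details, the first thing you must know is that the overall architecture of LLaMA 3 is the same as that of LLaMA 2. So if you haven't gone through the technical details of LLaMA 3 yet, following this blog won't be a problem. Even if you don't know the LLaMA 2 architecture, don't worry: we will also give a high-level overview of its technical details. Either way, this blog is designed for you.
Here are some key points about LLaMA 2 and LLaMA 3, in case you are already familiar with their architecture: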
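Understanding the Transformer Architecture of LLaMA 3
Understanding the architecture of LLaMA 3 is important before diving into the code. For a better visual understanding, here is a comparison diagram of the vanilla Transformer, LLaMA 2/3, and Mistral.
Let's look at the most important components of LLaMA 3 in a bit more detail: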
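1. Pre-normalization Using RMSNorm:
LLaMA 3 takes the same approach as LLaMA 2 and uses a technique called RMSNorm to normalize the input of each Transformer sub-layer.
This is where pre-normalization with RMSNorm comes in for large language models (LLMs) like ChatGPT. Loosely speaking, it is like assigning a weight to each chapter of a book based on its importance: chapters that are central to the subject get higher weights, and less important chapters get lower weights.
Similarly, pre-normalization with RMSNorm helps LLMs decide which parts of the text matter most for understanding context and meaning. It assigns higher weights to essential elements and lower weights to less important ones, so the model focuses its attention where accurate understanding is needed most. More precisely, RMSNorm rescales each input vector by its root mean square and then multiplies it by learned per-dimension weights, which keeps activations at a stable scale from layer to layer. Interested readers can explore the detailed implementation of RMSNorm here.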
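To make the rescaling concrete, here is a minimal plain-Python sketch of RMSNorm on a single vector. The function name `rms_norm_vector`, the toy input values, and the `eps` default are illustrative assumptions, not taken from the LLaMA 3 code.

```python
import math

def rms_norm_vector(x, gain, eps=1e-5):
    """Scale x by the inverse of its root mean square, then apply a per-dimension gain."""
    mean_square = sum(v * v for v in x) / len(x)
    inv_rms = 1.0 / math.sqrt(mean_square + eps)
    return [v * inv_rms * g for v, g in zip(x, gain)]

x = [3.0, -4.0, 0.0, 5.0]      # toy activation vector (assumed values)
gain = [1.0, 1.0, 1.0, 1.0]    # learned weights would go here; all ones for illustration
y = rms_norm_vector(x, gain)

# After normalization, the root mean square of the output is very close to 1.
rms_after = math.sqrt(sum(v * v for v in y) / len(y))
print(y, rms_after)
```

Note that nothing is subtracted from the input: unlike LayerNorm, this sketch only divides by a scale factor, which is what makes RMSNorm cheaper to compute.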
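2. SwiGLU Activation Function:
LLaMA introduces the SwiGLU activation function, drawing inspiration from PaLM.
SwiGLU is like a magic pen for large language models (LLMs) like ChatGPT. Before generating text, SwiGLU adjusts the importance of each word or phrase based on its relevance to the context. Just as a magic pen would adjust the size and style of your writing, SwiGLU adjusts the emphasis of each word or phrase.
SwiGLU: GLU Variants Improve Transformer (https://kikaben.com/swiglu-2020/)
So when the LLM generates text, it can give more prominence to the important parts, making them stand out and ensuring they contribute more to the overall understanding of the text. In this way, SwiGLU helps LLMs generate clearer, more understandable text, much like a magic pen helps you write clearer explanations for students on a whiteboard. For more details on SwiGLU, see the associated paper.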
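Behind the analogy, SwiGLU is a gated feed-forward computation: one linear branch passes through SiLU and multiplies a second linear branch element-wise before a final projection. The following is a minimal plain-Python sketch under assumed toy sizes; the helper names (`silu`, `matvec`, `swiglu_ffn`) and all weight values are hypothetical, while `w1`, `w3`, and `w2` simply mirror the feed-forward weight names that appear in the checkpoint later on.

```python
import math

def silu(z):
    """SiLU (also called Swish): z * sigmoid(z)."""
    return z / (1.0 + math.exp(-z))

def matvec(matrix, vector):
    """Multiply a matrix (list of rows) by a vector."""
    return [sum(w * v for w, v in zip(row, vector)) for row in matrix]

def swiglu_ffn(x, w1, w3, w2):
    """Gated feed-forward: w2 @ (silu(w1 @ x) * (w3 @ x))."""
    gate = [silu(z) for z in matvec(w1, x)]   # gating branch
    up = matvec(w3, x)                        # linear branch
    hidden = [g * u for g, u in zip(gate, up)]
    return matvec(w2, hidden)

# Toy sizes: model dimension 2, hidden dimension 3 (assumed values)
x = [1.0, -2.0]
w1 = [[0.5, 0.1], [-0.3, 0.8], [0.2, 0.2]]
w3 = [[1.0, 0.0], [0.0, 1.0], [0.5, 0.5]]
w2 = [[1.0, 0.0, 0.0], [0.0, 1.0, 0.0]]
out = swiglu_ffn(x, w1, w3, w2)
print(out)
```

The element-wise product is the "gate": where `silu(w1 @ x)` is near zero, the corresponding hidden unit contributes almost nothing to the output.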
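3. Rotary Embeddings (RoPE):
Rotary Embeddings, or RoPE, are a type of position embedding used in LLaMA 3.
RoPE is like a special seating arrangement that lets students rotate and change seats while still keeping their positions relative to one another. Instead of being fixed in one place, students can move around in a circular motion, which allows more fluid interaction.
In this picture, each student represents a word or token in a text sequence, and their seat corresponds to their position in the sequence. Just as the rotating arrangement lets students change seats, RoPE lets the positional embeddings of words in a sequence change dynamically based on their positions relative to each other.
So, when processing text, RoPE does not treat positional embeddings as fixed and static. It introduces a rotational aspect, which allows more flexible representations that capture the dynamic relationships between words in the sequence. This flexibility helps models like ChatGPT better understand and generate text that flows naturally and stays coherent, much as a dynamic seating arrangement encourages more interactive discussion in a classroom. Those interested in the mathematical details can refer to the RoPE paper.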
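The "relative position" idea can be checked numerically. The sketch below is a plain-Python illustration, assuming toy frequencies and toy vectors (none of these values come from LLaMA 3): it rotates a query and a key by their positions and shows that their dot product depends only on how far apart the two positions are.

```python
import cmath

def rotate(vec_pairs, position, freqs):
    """Rotate each (x, y) pair, stored as a complex number, by position * freq."""
    return [z * cmath.exp(1j * position * f) for z, f in zip(vec_pairs, freqs)]

def score(q, k):
    """Real dot product of two vectors stored as lists of complex pairs."""
    return sum((a * b.conjugate()).real for a, b in zip(q, k))

freqs = [1.0, 0.1]                            # toy rotation frequencies (assumed)
q = [complex(1.0, 2.0), complex(-0.5, 0.3)]   # toy query, two pairs
k = [complex(0.7, -1.0), complex(0.2, 0.9)]   # toy key, two pairs

# Same offset (2 positions apart) at two different absolute positions.
near_start = score(rotate(q, 3, freqs), rotate(k, 1, freqs))
further_on = score(rotate(q, 7, freqs), rotate(k, 5, freqs))
print(near_start, further_on)
```

The two printed scores agree up to floating-point error, because rotating both vectors by the same extra angle leaves the angle between them unchanged.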
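LLaMA 3 uses Byte Pair Encoding (BPE) from the tiktoken library introduced by OpenAI, whereas the LLaMA 2 tokenizer's BPE is based on the sentencepiece library. There is a slight difference between them, but first, let's look at what BPE actually is.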
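- First, we find the most frequent pair of consecutive characters. In this example, the most frequent pair is "bc", with a frequency of 2. We then merge this pair to create a new subword unit "bc". After merging, we update the frequency counts to reflect the new subword unit. The updated frequencies are {"a": 1, "b": 2, "c": 2, "d": 2, "e": 1, "bc": 2}. We add the new subword unit "bc" to the vocabulary, which now becomes {"a", "b", "c", "d", "e", "bc"}.
- We repeat the process. The next most frequent pair is "cd". We merge "cd" to form a new subword unit "cd" and update the frequency counts. The updated frequencies are {"a": 1, "b": 2, "c": 1, "d": 1, "e": 1, "bc": 2, "cd": 2}. We add "cd" to the vocabulary, resulting in {"a", "b", "c", "d", "e", "bc", "cd"}.
- Continuing the process, the next frequent pair is "de". We merge "de" to form the subword unit "de" and update the frequency counts to {"a": 1, "b": 2, "c": 1, "d": 1, "e": 0, "bc": 2, "cd": 1, "de": 1}. We add "de" to the vocabulary, making it {"a", "b", "c", "d", "e", "bc", "cd", "de"}.
- Next, we find "ab" as the most frequent pair. We merge "ab" to form the subword unit "ab" and update the frequency counts to {"a": 0, "b": 1, "c": 1, "d": 1, "e": 0, "bc": 2, "cd": 1, "de": 1, "ab": 1}. We add "ab" to the vocabulary, which becomes {"a", "b", "c", "d", "e", "bc", "cd", "de", "ab"}.
- Then, the next frequent pair is "bcd" (the existing unit "bc" followed by "d"). We merge it to form the subword unit "bcd" and update the frequency counts to {"a": 0, "b": 0, "c": 0, "d": 0, "e": 0, "bc": 1, "cd": 0, "de": 1, "ab": 1, "bcd": 1}. We add "bcd" to the vocabulary, resulting in {"a", "b", "c", "d", "e", "bc", "cd", "de", "ab", "bcd"}.
- Finally, the most frequent pair is "cde". We merge it to form the subword unit "cde" and update the frequency counts to {"a": 0, "b": 0, "c": 0, "d": 0, "e": 0, "bc": 1, "cd": 0, "de": 0, "ab": 1, "bcd": 1, "cde": 1}. We add "cde" to the vocabulary, making it {"a", "b", "c", "d", "e", "bc", "cd", "de", "ab", "bcd", "cde"}.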
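The merge loop described in the steps above can be sketched in a few lines of plain Python. This is a simplified illustration, not the training code of tiktoken or sentencepiece: the toy corpus, the helper names (`count_pairs`, `merge_pair`), and the alphabetical tie-breaking rule are all assumptions made for the example, so the exact merge order may differ from the walkthrough.

```python
from collections import Counter

def count_pairs(words):
    """Count adjacent symbol pairs across all words (each word is a tuple of symbols)."""
    pairs = Counter()
    for symbols, freq in words.items():
        for left, right in zip(symbols, symbols[1:]):
            pairs[(left, right)] += freq
    return pairs

def merge_pair(words, pair):
    """Replace every occurrence of `pair` with a single merged symbol."""
    merged = {}
    for symbols, freq in words.items():
        out, i = [], 0
        while i < len(symbols):
            if i + 1 < len(symbols) and (symbols[i], symbols[i + 1]) == pair:
                out.append(symbols[i] + symbols[i + 1])
                i += 2
            else:
                out.append(symbols[i])
                i += 1
        key = tuple(out)
        merged[key] = merged.get(key, 0) + freq
    return merged

# Toy corpus (assumed for illustration): word -> number of occurrences
corpus = {"ab": 1, "bc": 1, "bcd": 1, "cde": 1}
words = {tuple(word): freq for word, freq in corpus.items()}
vocab = sorted({ch for word in corpus for ch in word})

merges = []
for _ in range(3):                                      # learn three merges
    pairs = count_pairs(words)
    if not pairs:
        break
    best = max(sorted(pairs), key=lambda p: pairs[p])   # ties broken alphabetically
    words = merge_pair(words, best)
    merges.append(best)
    vocab.append(best[0] + best[1])

print(merges)
print(vocab)
```

On this toy corpus the first learned merge is ("b", "c"), because "bc" is one of the two pairs that occur twice and it sorts first. Real tokenizers run the same loop for many thousands of merges over byte sequences rather than letters.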
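This technique can improve the performance of LLMs and handle rare and out-of-vocabulary words. The big difference between tiktoken BPE and sentencepiece BPE is that tiktoken BPE doesn't always split words into smaller parts if the whole word is already known. For example, if "hugging" is in the vocabulary, it stays as one token instead of being split into ["hug", "ging"].
We will be working with a small number of Python libraries, and it's best to install them up front to avoid "no module found" errors.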
pip install sentencepiece tiktoken torch blobfile matplotlib huggingface_hub
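After installing the required libraries, we need to download some files. Since we are going to replicate the architecture of llama-3-8B, you must have an account on HuggingFace. Additionally, since llama-3 is a gated model, you have to accept its terms and conditions to access the model content.
(Option 1: Manual) Go to the llama-3-8B HF directory from this link and manually download each of these three files.
Downloading the LLaMA-3 configuration files
(Option 2: Coding) We can use the huggingface_hub library, which we installed earlier, to download all of these files. First, however, we need to log in to the HuggingFace Hub from our working notebook using our HF token. You can create a new token or access an existing one from this link.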
# Import the `notebook_login` function from the `huggingface_hub` module.
from huggingface_hub import notebook_login

# Execute the `notebook_login` function to log in to the Hugging Face Hub.
notebook_login()
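Once you run this cell, it will ask you to enter the token. If there is an error during login, retry it, but make sure to uncheck "Add token as git credential".
After that, we just need to run a short piece of Python code to download the three files that form the backbone of the llama-3-8B architecture.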
from huggingface_hub import hf_hub_download

# Define the repository information
repo_id = "meta-llama/Meta-Llama-3-8B"
subfolder = "original"  # Specify the subfolder within the repository

# List of filenames to download
filenames = ["params.json", "tokenizer.model", "consolidated.00.pth"]

# Specify the directory where you want to save the downloaded files
save_directory = "llama-3-8B/"  # Replace with your desired path

# Download each file.
# Note: the repository layout is mirrored locally, so the files end up under
# save_directory/original/. Adjust the paths used later when loading them.
for filename in filenames:
    hf_hub_download(
        repo_id=repo_id,          # Repository ID
        filename=filename,        # Name of the file to download
        subfolder=subfolder,      # Subfolder within the repository
        local_dir=save_directory  # Directory to save the downloaded file
    )
# Tokenization library
import tiktoken

# BPE loading function
from tiktoken.load import load_tiktoken_bpe

# PyTorch library
import torch

# JSON handling
import json
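Since we are aiming for an exact replication of llama-3, our input text must yield a meaningful output. For example, if our input is "the color of the sun is?", the output must be "white". Achieving this requires training our LLM on a large dataset, which demands high computational power and is not feasible for us.
However, Meta has publicly released their llama-3 architecture files, or in more complex terms, their pretrained weights, for use. We've just downloaded these files, which lets us replicate their architecture without training or a large dataset. Everything is already prepared; we just have to use the right components in the right places.
tokenizer.model: As discussed earlier, LLaMA-3 uses the Byte Pair Encoding (BPE) tokenizer from tiktoken, and the model was trained on a dataset of 15 trillion tokens, 7 times larger than the dataset used for LLaMA-2. Let's load this file and see what it holds.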
# Loading the tokenizer from llama-3-8B
# (adjust the path if the file was downloaded to a different directory)
tokenizer_model = load_tiktoken_bpe("tokenizer.model")

# Get the length of the tokenizer model
len(tokenizer_model)
# OUTPUT: 128000

# Get the type of the `tokenizer_model` object.
type(tokenizer_model)
# OUTPUT: dict

# Printing 10 items of the tokenizer model (indices 5600 to 5609)
dict(list(tokenizer_model.items())[5600:5610])

#### OUTPUT ####
{
  b'mitted': 5600,
  b" $('#": 5601,
  b' saw': 5602,
  b' approach': 5603,
  b'ICE': 5604,
  b' saying': 5605,
  b' anyone': 5606,
  b'meta': 5607,
  b'SD': 5608,
  b' song': 5609
}
#### OUTPUT ####
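When we print 10 items from it, you can see strings that were formed by the BPE algorithm, similar to the example we discussed earlier. Keys are byte sequences from BPE training, and values are merge ranks based on frequency.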
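To see how a table of merge ranks like this can turn text into token IDs, here is a simplified sketch of rank-based BPE encoding. It illustrates the idea only and is not tiktoken's actual implementation; `bpe_encode` and `toy_ranks` are hypothetical names, and the toy table merely has the same bytes-to-int shape as `tokenizer_model`.

```python
def bpe_encode(word, ranks):
    """Greedily merge the adjacent pair with the lowest rank until none remain."""
    parts = [bytes([b]) for b in word]           # start from single bytes
    while len(parts) > 1:
        candidates = [
            (ranks[parts[i] + parts[i + 1]], i)
            for i in range(len(parts) - 1)
            if parts[i] + parts[i + 1] in ranks
        ]
        if not candidates:
            break
        _, i = min(candidates)                   # lowest rank wins
        parts[i:i + 2] = [parts[i] + parts[i + 1]]
    return [ranks[p] for p in parts]

# Hypothetical ranks table (bytes -> int), much smaller than the real one
toy_ranks = {b"h": 0, b"u": 1, b"g": 2, b"s": 3, b"hu": 4, b"ug": 5, b"hug": 6}
print(bpe_encode(b"hugs", toy_ranks))
```

Here "hugs" becomes two tokens: the merges "hu" and then "hug" are applied, and "s" stays on its own because no entry in the table covers "hugs" or "gs".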
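consolidated.00.pth: Contains the learned parameters (weights) of Llama-3-8B. These parameters encode how the model understands and processes language: how it represents tokens, computes attention, performs feed-forward transformations, and normalizes its outputs.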
# Loading a PyTorch model of LLaMA-3-8B
# (adjust the path if the file was downloaded to a different directory)
model = torch.load("consolidated.00.pth")

# printing the first 11 weight names of the architecture
list(model.keys())[:11]

#### OUTPUT ####
[
  'tok_embeddings.weight',
  'layers.0.attention.wq.weight',
  'layers.0.attention.wk.weight',
  'layers.0.attention.wv.weight',
  'layers.0.attention.wo.weight',
  'layers.0.feed_forward.w1.weight',
  'layers.0.feed_forward.w3.weight',
  'layers.0.feed_forward.w2.weight',
  'layers.0.attention_norm.weight',
  'layers.0.ffn_norm.weight',
  'layers.1.attention.wq.weight',
]
#### OUTPUT ####
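If you're familiar with the Transformer architecture, you already know about query and key matrices and so on. Later, we will use these layers/weights to build such matrices within the Llama-3 architecture.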
# Opening the parameters JSON file
with open("params.json", "r") as f:
    config = json.load(f)

# Printing the content
print(config)

#### OUTPUT ####
{
  'dim': 4096,
  'n_layers': 32,
  'n_heads': 32,
  'n_kv_heads': 8,
  'vocab_size': 128256,
  'multiple_of': 1024,
  'ffn_dim_multiplier': 1.3,
  'norm_eps': 1e-05,
  'rope_theta': 500000.0
}
#### OUTPUT ####
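These values will help us replicate the Llama-3 architecture by specifying details such as the number of heads, the dimension of the embedding vector, and more.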
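A few useful sizes follow directly from these configuration values. The sketch below derives them in plain Python; the feed-forward width calculation assumes the rounding rule used in Meta's reference implementation, which this article does not show, so treat that part as an assumption.

```python
# Values copied from the params.json output shown above
dim, n_heads, n_kv_heads = 4096, 32, 8
multiple_of, ffn_dim_multiplier = 1024, 1.3

head_dim = dim // n_heads                    # size of each query/key/value vector
heads_per_kv_group = n_heads // n_kv_heads   # query heads sharing one key/value head

# Feed-forward width (assumed rule): start from 4 * dim, take two thirds,
# apply the multiplier, then round up to a multiple of `multiple_of`.
hidden_dim = int(2 * (4 * dim) / 3)
hidden_dim = int(ffn_dim_multiplier * hidden_dim)
hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of)

print(head_dim, heads_per_kv_group, hidden_dim)
```

With these inputs the sketch gives a head dimension of 128, 4 query heads per key/value head, and a feed-forward width of 14336.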
# Dimension
dim = config["dim"]

# Layers
n_layers = config["n_layers"]

# Heads
n_heads = config["n_heads"]

# KV_heads
n_kv_heads = config["n_kv_heads"]

# Vocabulary
vocab_size = config["vocab_size"]

# Multiple
multiple_of = config["multiple_of"]

# Multiplier
ffn_dim_multiplier = config["ffn_dim_multiplier"]

# Epsilon
norm_eps = config["norm_eps"]

# RoPE
rope_theta = torch.tensor(config["rope_theta"])
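Now that we have the tokenizer model, the architecture model containing weights, and the configuration parameters, let's start coding our own Llama-3 from scratch.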
special_tokens = [
    "<|begin_of_text|>",  # Marks the beginning of a text sequence.
    "<|end_of_text|>",  # Marks the end of a text sequence.
    "<|reserved_special_token_0|>",  # Reserved for future use.
    "<|reserved_special_token_1|>",  # Reserved for future use.
    "<|reserved_special_token_2|>",  # Reserved for future use.
    "<|reserved_special_token_3|>",  # Reserved for future use.
    "<|start_header_id|>",  # Indicates the start of a header ID.
    "<|end_header_id|>",  # Indicates the end of a header ID.
    "<|reserved_special_token_4|>",  # Reserved for future use.
    "<|eot_id|>",  # Marks the end of a turn (in a conversational context).
] + [f"<|reserved_special_token_{i}|>" for i in range(5, 256 - 5)]  # A large set of tokens reserved for future use.
# patterns based on which text will be broken into tokens
tokenize_breaker = r"(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+"
We need to code a simple tokenizer function using tiktoken BPE that takes three inputs: tokenizer_model, tokenize_breaker, and special_tokens. This function will encode and decode our input text accordingly.
# Initialize tokenizer with specified parameters
tokenizer = tiktoken.Encoding(

    # Name of the encoding (an identifier, not a file path)
    name = "tokenizer.model",

    # Define tokenization pattern string
    pat_str = tokenize_breaker,

    # Assign BPE mergeable ranks from tokenizer_model of LLaMA-3
    mergeable_ranks = tokenizer_model,

    # Set special tokens with indices
    special_tokens={token: len(tokenizer_model) + i for i, token in enumerate(special_tokens)},
)

# Encode "hello world!" and decode tokens to string
tokenizer.decode(tokenizer.encode("hello world!"))

#### OUTPUT ####
hello world!
#### OUTPUT ####
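To verify that our encoder works correctly, we pass "hello world!" to it. First it encodes the text into numerical values, and then it decodes them back to text, giving "hello world!". This confirms that the function works. Let's tokenize our input.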
# input prompt
prompt = "the answer to the ultimate question of life, the universe, and everything is "

# Encode the prompt using the tokenizer and prepend a special token (128000)
tokens = [128000] + tokenizer.encode(prompt)

print(tokens)  # Print the encoded tokens

# Convert the list of tokens into a PyTorch tensor
tokens = torch.tensor(tokens)

# Decode each token back into its corresponding string
prompt_split_as_tokens = [tokenizer.decode([token.item()]) for token in tokens]

print(prompt_split_as_tokens)  # Print the decoded tokens

#### OUTPUT ####
[128000, 1820, 4320, 311, ... ]
['<|begin_of_text|>', 'the', ' answer', ' to', ... ]
#### OUTPUT ####

# checking dimension of input vector
len(tokens)

#### OUTPUT ####
17
#### OUTPUT ####

# checking dimension of embedding vector from llama-3 architecture
print(dim)

#### OUTPUT ####
4096
#### OUTPUT ####
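Our input currently has shape (17×1) and needs to be turned into an embedding for each token. This means our (17×1) tokens will become (17×4096), where each token has a corresponding embedding of length 4096.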
# Define embedding layer with vocab size and embedding dimension
embedding_layer = torch.nn.Embedding(vocab_size, dim)

# Copy pre-trained token embeddings to the embedding layer
embedding_layer.weight.data.copy_(model["tok_embeddings.weight"])

# Get token embeddings for given tokens, converting to torch.bfloat16 format
token_embeddings_unnormalized = embedding_layer(tokens).to(torch.bfloat16)

# Print shape of resulting token embeddings
token_embeddings_unnormalized.shape

#### OUTPUT ####
torch.Size([17, 4096])
#### OUTPUT ####
Normalization Using RMSNorm
We will normalize the input vectors using the same RMSNorm formula we saw earlier, so that our inputs are normalized.
# Calculating RMSNorm
def rms_norm(tensor, norm_weights):

    # Calculate the mean of the square of tensor values along the last dimension
    squared_mean = tensor.pow(2).mean(-1, keepdim=True)

    # Add a small epsilon to avoid division by zero, then take the reciprocal square root
    normalized = torch.rsqrt(squared_mean + norm_eps)

    # Scale the tensor by the reciprocal RMS and multiply by the provided normalization weights
    return (tensor * normalized) * norm_weights
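We will use the attention normalization weights from layer_0 (attention_norm) to normalize our unnormalized embeddings. We use layer_0 because we are now building the first layer of the LLaMA-3 Transformer architecture.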
# using RMS normalization and provided normalization weights
token_embeddings = rms_norm(token_embeddings_unnormalized,
                            model["layers.0.attention_norm.weight"])

# Print the shape of the resulting token embeddings
token_embeddings.shape

#### OUTPUT ####
torch.Size([17, 4096])
#### OUTPUT ####

# Print the shapes of different weights
print(
    # Query weight shape
    model["layers.0.attention.wq.weight"].shape,

    # Key weight shape
    model["layers.0.attention.wk.weight"].shape,

    # Value weight shape
    model["layers.0.attention.wv.weight"].shape,

    # Output weight shape
    model["layers.0.attention.wo.weight"].shape
)

#### OUTPUT ####
torch.Size([4096, 4096]) # Query weight dimension
torch.Size([1024, 4096]) # Key weight dimension
torch.Size([1024, 4096]) # Value weight dimension
torch.Size([4096, 4096]) # Output weight dimension
#### OUTPUT ####

# Retrieve query weight for the first layer of attention
q_layer0 = model["layers.0.attention.wq.weight"]

# Calculate dimension per head
head_dim = q_layer0.shape[0] // n_heads

# Reshape query weight to separate heads
q_layer0 = q_layer0.view(n_heads, head_dim, dim)

# Print the shape of the reshaped query weight tensor
q_layer0.shape

#### OUTPUT ####
torch.Size([32, 128, 4096])
#### OUTPUT ####
Here, 32 is the number of attention heads in Llama-3, 128 is the size of the query vector, and 4096 is the size of the token embedding.
# Extract the query weight for the first head of the first layer of attention
q_layer0_head0 = q_layer0[0]

# Print the shape of the extracted query weight tensor for the first head
q_layer0_head0.shape

#### OUTPUT ####
torch.Size([128, 4096])
#### OUTPUT ####

# Matrix multiplication: token embeddings with transpose of query weight for first head
q_per_token = torch.matmul(token_embeddings, q_layer0_head0.T)

# Shape of resulting tensor: queries per token
q_per_token.shape

#### OUTPUT ####
torch.Size([17, 128])
#### OUTPUT ####
The query vectors don't inherently know their position in the prompt, so we will use RoPE to make them aware of it.
Implementing RoPE
# Convert queries per token to float and split into pairs
q_per_token_split_into_pairs = q_per_token.float().view(q_per_token.shape[0], -1, 2)

# Print the shape of the resulting tensor after splitting into pairs
q_per_token_split_into_pairs.shape

#### OUTPUT ####
torch.Size([17, 64, 2])
#### OUTPUT ####
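We have a tensor of size [17x64x2], which represents the length-128 queries split into 64 pairs for each token in the prompt. Each pair will be rotated by m*theta, where m is the position of the token whose query we are rotating.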
# Generate values from 0 to 1 split into 64 parts
zero_to_one_split_into_64_parts = torch.tensor(range(64))/64

# Print the resulting tensor
zero_to_one_split_into_64_parts

#### OUTPUT ####
tensor([0.0000, 0.0156, 0.0312, 0.0469, 0.0625, 0.0781, 0.0938, 0.1094,
        0.1250, 0.1406, 0.1562, 0.1719, 0.1875, 0.2031, 0.2188, 0.2344,
        0.2500, 0.2656, 0.2812, 0.2969, 0.3125, 0.3281, 0.3438, 0.3594,
        0.3750, 0.3906, 0.4062, 0.4219, 0.4375, 0.4531, 0.4688, 0.4844,
        0.5000, 0.5156, 0.5312, 0.5469, 0.5625, 0.5781, 0.5938, 0.6094,
        0.6250, 0.6406, 0.6562, 0.6719, 0.6875, 0.7031, 0.7188, 0.7344,
        0.7500, 0.7656, 0.7812, 0.7969, 0.8125, 0.8281, 0.8438, 0.8594,
        0.8750, 0.8906, 0.9062, 0.9219, 0.9375, 0.9531, 0.9688, 0.9844])
#### OUTPUT ####
# Calculate frequencies using a power operation
freqs = 1.0 / (rope_theta ** zero_to_one_split_into_64_parts)

# Display the resulting frequencies
freqs

#### OUTPUT ####
tensor([1.0000e+00, 8.1462e-01, 6.6360e-01, 5.4058e-01, 4.4037e-01, 3.5873e-01, 2.9223e-01, 2.3805e-01,
        1.9392e-01, 1.5797e-01, 1.2869e-01, 1.0483e-01, 8.5397e-02, 6.9566e-02, 5.6670e-02, 4.6164e-02,
        3.7606e-02, 3.0635e-02, 2.4955e-02, 2.0329e-02, 1.6560e-02, 1.3490e-02, 1.0990e-02, 8.9523e-03,
        7.2927e-03, 5.9407e-03, 4.8394e-03, 3.9423e-03, 3.2114e-03, 2.6161e-03, 2.1311e-03, 1.7360e-03,
        1.4142e-03, 1.1520e-03, 9.3847e-04, 7.6450e-04, 6.2277e-04, 5.0732e-04, 4.1327e-04, 3.3666e-04,
        2.7425e-04, 2.2341e-04, 1.8199e-04, 1.4825e-04, 1.2077e-04, 9.8381e-05, 8.0143e-05, 6.5286e-05,
        5.3183e-05, 4.3324e-05, 3.5292e-05, 2.8750e-05, 2.3420e-05, 1.9078e-05, 1.5542e-05, 1.2660e-05,
        1.0313e-05, 8.4015e-06, 6.8440e-06, 5.5752e-06, 4.5417e-06, 3.6997e-06, 3.0139e-06, 2.4551e-06])
#### OUTPUT ####
# Convert queries per token to complex numbers
q_per_token_as_complex_numbers = torch.view_as_complex(q_per_token_split_into_pairs)

q_per_token_as_complex_numbers.shape
# Output: torch.Size([17, 64])

# Calculate frequencies for each token using outer product of arange(17) and freqs
freqs_for_each_token = torch.outer(torch.arange(17), freqs)

# Calculate complex numbers from frequencies_for_each_token using polar coordinates
freqs_cis = torch.polar(torch.ones_like(freqs_for_each_token), freqs_for_each_token)

# Rotate complex numbers by frequencies
q_per_token_as_complex_numbers_rotated = q_per_token_as_complex_numbers * freqs_cis

q_per_token_as_complex_numbers_rotated.shape
# Output: torch.Size([17, 64])

# Convert rotated complex numbers back to real numbers
q_per_token_split_into_pairs_rotated = torch.view_as_real(q_per_token_as_complex_numbers_rotated)

# Print the shape of the resulting tensor
q_per_token_split_into_pairs_rotated.shape

#### OUTPUT ####
torch.Size([17, 64, 2])
#### OUTPUT ####
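Now we merge the rotated pairs back, giving a new query vector (the rotated query vector) of shape [17×128], where 17 is the number of tokens and 128 is the dimension of the query vector.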
# Reshape rotated token queries to match the original shape
q_per_token_rotated = q_per_token_split_into_pairs_rotated.view(q_per_token.shape)

# Print the shape of the resulting tensor
q_per_token_rotated.shape

#### OUTPUT ####
torch.Size([17, 128])
#### OUTPUT ####
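The rotation steps above can be checked on a single pair without PyTorch. This plain-Python sketch assumes the same `rope_theta` and head dimension as the configuration file; the helper names `rotate_pair` and `rotate_pair_matrix` and the sample point are hypothetical. It shows that multiplying by a unit complex number gives the same result as the familiar 2×2 rotation matrix.

```python
import cmath
import math

rope_theta = 500000.0      # value from params.json
head_dim = 128

# One frequency per pair: theta^(-i / 64) for i = 0..63
freqs = [1.0 / rope_theta ** (i / (head_dim // 2)) for i in range(head_dim // 2)]

def rotate_pair(x, y, position, freq):
    """Rotate the 2-D point (x, y) by the angle position * freq using complex numbers."""
    z = complex(x, y) * cmath.exp(1j * position * freq)
    return z.real, z.imag

def rotate_pair_matrix(x, y, position, freq):
    """The same rotation written with the usual 2x2 rotation matrix, for comparison."""
    angle = position * freq
    return (x * math.cos(angle) - y * math.sin(angle),
            x * math.sin(angle) + y * math.cos(angle))

a = rotate_pair(1.0, 2.0, position=5, freq=freqs[1])
b = rotate_pair_matrix(1.0, 2.0, position=5, freq=freqs[1])
print(freqs[:3], a, b)
```

The first few frequencies match the `freqs` tensor printed earlier, and the two rotation results agree up to floating-point error. Pairs with a higher index use smaller frequencies, so they rotate more slowly as the position grows.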
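For keys the process is similar, and keep in mind that key vectors are also 128-dimensional. Keys have only a quarter as many weights as queries because each key head is shared across 4 query heads at a time, which reduces computation. Like queries, keys are also rotated to incorporate positional information.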
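The sharing pattern is easy to make explicit. The sketch below, using the values from params.json, maps each of the 32 query heads to the key/value head it reads from and compares the parameter counts of the query and key projections; the variable names are illustrative.

```python
# Values from params.json; head_dim = dim // n_heads
n_heads, n_kv_heads, dim, head_dim = 32, 8, 4096, 128

group_size = n_heads // n_kv_heads            # query heads per shared key/value head
kv_head_for = [head // group_size for head in range(n_heads)]

wq_params = n_heads * head_dim * dim          # 4096 x 4096 query projection
wk_params = n_kv_heads * head_dim * dim       # 1024 x 4096 key projection

print(kv_head_for)
print(wq_params // wk_params)
```

Query heads 0 to 3 all use key/value head 0, heads 4 to 7 use head 1, and so on, which is why the key projection holds a quarter of the query projection's parameters.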
# Extract the weight tensor for the attention mechanism's key in the first layer of the model
k_layer0 = model["layers.0.attention.wk.weight"]

# Reshape key weight for the first layer of attention to separate heads
k_layer0 = k_layer0.view(n_kv_heads, k_layer0.shape[0] // n_kv_heads, dim)

# Print the shape of the reshaped key weight tensor
k_layer0.shape  # Output: torch.Size([8, 128, 4096])

# Extract the key weight for the first head of the first layer of attention
k_layer0_head0 = k_layer0[0]

# Print the shape of the extracted key weight tensor for the first head
k_layer0_head0.shape  # Output: torch.Size([128, 4096])

# Calculate key per token by matrix multiplication
k_per_token = torch.matmul(token_embeddings, k_layer0_head0.T)

# Print the shape of the resulting tensor representing keys per token
k_per_token.shape  # Output: torch.Size([17, 128])

# Split key per token into pairs and convert to float
k_per_token_split_into_pairs = k_per_token.float().view(k_per_token.shape[0], -1, 2)

# Print the shape of the resulting tensor after splitting into pairs
k_per_token_split_into_pairs.shape  # Output: torch.Size([17, 64, 2])

# Convert key per token to complex numbers
k_per_token_as_complex_numbers = torch.view_as_complex(k_per_token_split_into_pairs)

# Print the shape of the resulting tensor representing key per token as complex numbers
k_per_token_as_complex_numbers.shape  # Output: torch.Size([17, 64])

# Rotate complex key per token by frequencies
k_per_token_split_into_pairs_rotated = torch.view_as_real(k_per_token_as_complex_numbers * freqs_cis)

# Print the shape of the rotated complex key per token
k_per_token_split_into_pairs_rotated.shape  # Output: torch.Size([17, 64, 2])

# Reshape rotated key per token to match the original shape
k_per_token_rotated = k_per_token_split_into_pairs_rotated.view(k_per_token.shape)

# Print the shape of the rotated key per token
k_per_token_rotated.shape  # Output: torch.Size([17, 128])
We now have the rotated queries and keys for each token, each of size [17×128].
# Calculate query-key dot products per token
qk_per_token = torch.matmul(q_per_token_rotated, k_per_token_rotated.T) / (head_dim) ** 0.5

# Print the shape of the resulting tensor representing query-key dot products per token
qk_per_token.shape

#### OUTPUT ####
torch.Size([17, 17])
#### OUTPUT ####
The [17×17] shape represents the attention scores (qk_per_token), where 17 is the number of tokens in the prompt.
# Create a mask tensor filled with negative infinity values
mask = torch.full((len(tokens), len(tokens)), float("-inf"), device=tokens.device)

# Keep -inf strictly above the diagonal and set everything else to 0
mask = torch.triu(mask, diagonal=1)

# Print the resulting mask tensor
mask

#### OUTPUT ####
tensor([[0., -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf],
        [0., 0., -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf],
        [0., 0., 0., -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf],
        [0., 0., 0., 0., -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf],
        [0., 0., 0., 0., 0., -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf],
        [0., 0., 0., 0., 0., 0., -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf],
        [0., 0., 0., 0., 0., 0., 0., -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf],
        [0., 0., 0., 0., 0., 0., 0., 0., -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., -inf, -inf, -inf, -inf, -inf, -inf, -inf],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., -inf, -inf, -inf, -inf, -inf, -inf],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., -inf, -inf, -inf, -inf, -inf],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., -inf, -inf, -inf, -inf],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., -inf, -inf, -inf],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., -inf, -inf],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., -inf],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]])
#### OUTPUT ####
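Now we have to apply the mask to the query-key scores of each token. We then apply softmax on top, which turns each row of scores into attention weights that sum to 1. The masked (future) positions end up with a weight of 0, so each token can only attend to itself and the tokens before it.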
# Add the mask to the query-key dot products per token
qk_per_token_after_masking = qk_per_token + mask

# Apply softmax along the second dimension after masking
qk_per_token_after_masking_after_softmax = torch.nn.functional.softmax(qk_per_token_after_masking, dim=1).to(torch.bfloat16)
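The effect of masking followed by softmax is easiest to see on a tiny example. This plain-Python sketch uses an assumed 3×3 score matrix and a hypothetical helper `causal_softmax`; it applies the same idea as the mask above, where future positions are set to negative infinity before the softmax.

```python
import math

def causal_softmax(scores):
    """Row-wise softmax where token i may only attend to tokens 0..i."""
    weights = []
    for i, row in enumerate(scores):
        masked = [s if j <= i else float("-inf") for j, s in enumerate(row)]
        peak = max(masked)                           # subtract the max for numerical stability
        exps = [math.exp(s - peak) for s in masked]  # exp(-inf) evaluates to 0.0
        total = sum(exps)
        weights.append([e / total for e in exps])
    return weights

scores = [[0.2, 1.5, -0.3],
          [0.7, 0.1, 2.0],
          [1.0, 1.0, 1.0]]      # toy 3x3 attention scores (assumed values)
weights = causal_softmax(scores)
for row in weights:
    print([round(w, 3) for w in row])
```

The first token can only attend to itself, so its row is all weight on position 0. Every row sums to 1, and every entry above the diagonal is exactly 0.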
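The value matrix marks the end of the self-attention part. Like keys, value weights are shared across every 4 attention heads to save computation. As a result, the value weight matrix has shape [8x128x4096].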
# Retrieve the value weight for the first layer of attention
v_layer0 = model["layers.0.attention.wv.weight"]

# Reshape value weight for the first layer of attention to separate heads
v_layer0 = v_layer0.view(n_kv_heads, v_layer0.shape[0] // n_kv_heads, dim)

# Print the shape of the reshaped value weight tensor
v_layer0.shape

#### OUTPUT ####
torch.Size([8, 128, 4096])
#### OUTPUT ####

# Extract the value weight for the first head of the first layer of attention
v_layer0_head0 = v_layer0[0]

# Print the shape of the extracted value weight tensor for the first head
v_layer0_head0.shape

#### OUTPUT ####
torch.Size([128, 4096])
#### OUTPUT ####
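Using the value weights, we compute the attention values for each token, giving a matrix of size [17×128], where 17 is the number of tokens in the prompt and 128 is the dimension of the value vector for each token.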
# Calculate value per token by matrix multiplication
v_per_token = torch.matmul(token_embeddings, v_layer0_head0.T)

# Print the shape of the resulting tensor representing values per token
v_per_token.shape

#### OUTPUT ####
torch.Size([17, 128])
#### OUTPUT ####

# Calculate QKV attention by matrix multiplication
qkv_attention = torch.matmul(qk_per_token_after_masking_after_softmax, v_per_token)

# Print the shape of the resulting tensor
qkv_attention.shape

#### OUTPUT ####
torch.Size([17, 128])
#### OUTPUT ####
# Store QKV attention for each head in a list
qkv_attention_store = []

# Iterate through each head
for head in range(n_heads):
    # Extract query, key, and value weights for the current head
    q_layer0_head = q_layer0[head]
    k_layer0_head = k_layer0[head//4]  # Key weights are shared across 4 heads
    v_layer0_head = v_layer0[head//4]  # Value weights are shared across 4 heads

    # Calculate query per token by matrix multiplication
    q_per_token = torch.matmul(token_embeddings, q_layer0_head.T)

    # Calculate key per token by matrix multiplication
    k_per_token = torch.matmul(token_embeddings, k_layer0_head.T)

    # Calculate value per token by matrix multiplication
    v_per_token = torch.matmul(token_embeddings, v_layer0_head.T)

    # Split query per token into pairs and rotate them
    q_per_token_split_into_pairs = q_per_token.float().view(q_per_token.shape[0], -1, 2)
    q_per_token_as_complex_numbers = torch.view_as_complex(q_per_token_split_into_pairs)
    q_per_token_split_into_pairs_rotated = torch.view_as_real(q_per_token_as_complex_numbers * freqs_cis[:len(tokens)])
    q_per_token_rotated = q_per_token_split_into_pairs_rotated.view(q_per_token.shape)

    # Split key per token into pairs and rotate them
    k_per_token_split_into_pairs = k_per_token.float().view(k_per_token.shape[0], -1, 2)
    k_per_token_as_complex_numbers = torch.view_as_complex(k_per_token_split_into_pairs)
    k_per_token_split_into_pairs_rotated = torch.view_as_real(k_per_token_as_complex_numbers * freqs_cis[:len(tokens)])
    k_per_token_rotated = k_per_token_split_into_pairs_rotated.view(k_per_token.shape)

    # Calculate query-key dot products per token
    qk_per_token = torch.matmul(q_per_token_rotated, k_per_token_rotated.T) / (128) ** 0.5

    # Create a mask tensor filled with negative infinity values
    mask = torch.full((len(tokens), len(tokens)), float("-inf"), device=tokens.device)
    # Keep -inf strictly above the diagonal and set everything else to 0
    mask = torch.triu(mask, diagonal=1)
    # Add the mask to the query-key dot products per token
    qk_per_token_after_masking = qk_per_token + mask

    # Apply softmax along the second dimension after masking
    qk_per_token_after_masking_after_softmax = torch.nn.functional.softmax(qk_per_token_after_masking, dim=1).to(torch.bfloat16)

    # Calculate QKV attention by matrix multiplication
    qkv_attention = torch.matmul(qk_per_token_after_masking_after_softmax, v_per_token)

    # Store QKV attention for the current head
    qkv_attention_store.append(qkv_attention)

# Print the number of QKV attentions stored
len(qkv_attention_store)

#### OUTPUT ####
32
#### OUTPUT ####
We now have the QKV attention matrices for all 32 heads of the first layer. Next, the attention outputs are merged into one large matrix of size [17×4096].
# Concatenate QKV attentions from all heads along the last dimension
stacked_qkv_attention = torch.cat(qkv_attention_store, dim=-1)

# Print the shape of the resulting tensor
stacked_qkv_attention.shape

#### OUTPUT ####
torch.Size([17, 4096])
#### OUTPUT ####
One of the last steps of layer 0 attention is to multiply the stacked QKV matrix by the output weight matrix.
# Calculate the embedding delta by matrix multiplication with the output weight
embedding_delta = torch.matmul(stacked_qkv_attention, model["layers.0.attention.wo.weight"].T)

# Print the shape of the resulting tensor
embedding_delta.shape

#### OUTPUT ####
torch.Size([17, 4096])
#### OUTPUT ####

# Add the embedding delta to the unnormalized token embeddings to get the final embeddings
embedding_after_edit = token_embeddings_unnormalized + embedding_delta

# Print the shape of the resulting tensor
embedding_after_edit.shape

#### OUTPUT ####
torch.Size([17, 4096])
#### OUTPUT ####

# Normalize edited embeddings using root mean square normalization and provided weights
embedding_after_edit_normalized = rms_norm(embedding_after_edit, model["layers.0.ffn_norm.weight"])

# Print the shape of resulting normalized embeddings
embedding_after_edit_normalized.shape

#### OUTPUT ####
torch.Size([17, 4096])
#### OUTPUT ####
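Implementing the SwiGLU Activation Function
Since we are familiar with the SwiGLU activation function from the previous section, we will apply the equation we studied earlier here.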
SwiGLU: GLU Variants Improve Transformer (https://kikaben.com/swiglu-2020/)
# Retrieve weights for feedforward layer
w1 = model["layers.0.feed_forward.w1.weight"]
w2 = model["layers.0.feed_forward.w2.weight"]
w3 = model["layers.0.feed_forward.w3.weight"]

# Perform operations for feedforward layer: w2 @ (silu(w1 @ x) * (w3 @ x))
output_after_feedforward = torch.matmul(
    torch.nn.functional.silu(torch.matmul(embedding_after_edit_normalized, w1.T))
    * torch.matmul(embedding_after_edit_normalized, w3.T),
    w2.T,
)

# Print the shape of the resulting tensor after feedforward
output_after_feedforward.shape

#### OUTPUT ####
torch.Size([17, 4096])
#### OUTPUT ####
Now that everything is in place, we need to merge our code to generate the remaining 31 layers.
# Initialize final embedding with unnormalized token embeddings
final_embedding = token_embeddings_unnormalized

# Iterate through each layer
for layer in range(n_layers):
    # Initialize list to store QKV attentions for each head
    qkv_attention_store = []

    # Normalize the final embedding using root mean square normalization and weights from the current layer
    layer_embedding_norm = rms_norm(final_embedding, model[f"layers.{layer}.attention_norm.weight"])

    # Retrieve query, key, value, and output weights for the attention mechanism of the current layer
    q_layer = model[f"layers.{layer}.attention.wq.weight"]
    q_layer = q_layer.view(n_heads, q_layer.shape[0] // n_heads, dim)
    k_layer = model[f"layers.{layer}.attention.wk.weight"]
    k_layer = k_layer.view(n_kv_heads, k_layer.shape[0] // n_kv_heads, dim)
    v_layer = model[f"layers.{layer}.attention.wv.weight"]
    v_layer = v_layer.view(n_kv_heads, v_layer.shape[0] // n_kv_heads, dim)
    w_layer = model[f"layers.{layer}.attention.wo.weight"]

    # Iterate through each head
    for head in range(n_heads):
        # Extract query, key, and value weights for the current head
        q_layer_head = q_layer[head]
        k_layer_head = k_layer[head//4]  # Key weights are shared across 4 heads
        v_layer_head = v_layer[head//4]  # Value weights are shared across 4 heads

        # Calculate query per token by matrix multiplication
        q_per_token = torch.matmul(layer_embedding_norm, q_layer_head.T)

        # Calculate key per token by matrix multiplication
        k_per_token = torch.matmul(layer_embedding_norm, k_layer_head.T)

        # Calculate value per token by matrix multiplication
        v_per_token = torch.matmul(layer_embedding_norm, v_layer_head.T)

        # Split query per token into pairs and rotate them
        q_per_token_split_into_pairs = q_per_token.float().view(q_per_token.shape[0], -1, 2)
        q_per_token_as_complex_numbers = torch.view_as_complex(q_per_token_split_into_pairs)
        q_per_token_split_into_pairs_rotated = torch.view_as_real(q_per_token_as_complex_numbers * freqs_cis)
        q_per_token_rotated = q_per_token_split_into_pairs_rotated.view(q_per_token.shape)

        # Split key per token into pairs and rotate them
        k_per_token_split_into_pairs = k_per_token.float().view(k_per_token.shape[0], -1, 2)
        k_per_token_as_complex_numbers = torch.view_as_complex(k_per_token_split_into_pairs)
        k_per_token_split_into_pairs_rotated = torch.view_as_real(k_per_token_as_complex_numbers * freqs_cis)
        k_per_token_rotated = k_per_token_split_into_pairs_rotated.view(k_per_token.shape)

        # Calculate query-key dot products per token
        qk_per_token = torch.matmul(q_per_token_rotated, k_per_token_rotated.T) / (128) ** 0.5

        # Create a mask tensor filled with negative infinity values
        mask = torch.full((len(token_embeddings_unnormalized), len(token_embeddings_unnormalized)), float("-inf"))
        # Keep -inf strictly above the diagonal and set everything else to 0
        mask = torch.triu(mask, diagonal=1)
        # Add the mask to the query-key dot products per token
        qk_per_token_after_masking = qk_per_token + mask

        # Apply softmax along the second dimension after masking
        qk_per_token_after_masking_after_softmax = torch.nn.functional.softmax(qk_per_token_after_masking, dim=1).to(torch.bfloat16)

        # Calculate QKV attention by matrix multiplication
        qkv_attention = torch.matmul(qk_per_token_after_masking_after_softmax, v_per_token)

        # Store QKV attention for the current head
        qkv_attention_store.append(qkv_attention)

    # Concatenate QKV attentions from all heads along the last dimension
    stacked_qkv_attention = torch.cat(qkv_attention_store, dim=-1)

    # Calculate embedding delta by matrix multiplication with the output weight
    embedding_delta = torch.matmul(stacked_qkv_attention, w_layer.T)

    # Add the embedding delta to the current embedding to get the edited embedding
    embedding_after_edit = final_embedding + embedding_delta

    # Normalize the edited embedding using root mean square normalization and weights from the current layer
    embedding_after_edit_normalized = rms_norm(embedding_after_edit, model[f"layers.{layer}.ffn_norm.weight"])

    # Retrieve weights for the feedforward layer
    w1 = model[f"layers.{layer}.feed_forward.w1.weight"]
    w2 = model[f"layers.{layer}.feed_forward.w2.weight"]
    w3 = model[f"layers.{layer}.feed_forward.w3.weight"]

    # Perform operations for the feedforward layer
    output_after_feedforward = torch.matmul(
        torch.nn.functional.silu(torch.matmul(embedding_after_edit_normalized, w1.T))
        * torch.matmul(embedding_after_edit_normalized, w3.T),
        w2.T,
    )

    # Update the final embedding with the edited embedding plus the output from the feedforward layer
    final_embedding = embedding_after_edit + output_after_feedforward
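We now have the final embedding, which represents the model's guess for the next token. Its shape is the same as that of the regular token embeddings, [17×4096], with 17 tokens and an embedding dimension of 4096.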
# Normalize the final embedding using root mean square normalization and provided weights
final_embedding = rms_norm(final_embedding, model["norm.weight"])

# Print the shape of the resulting normalized final embedding
final_embedding.shape

#### OUTPUT ####
torch.Size([17, 4096])
#### OUTPUT ####

# Print the shape of the output weight tensor
model["output.weight"].shape

#### OUTPUT ####
torch.Size([128256, 4096])
#### OUTPUT ####

# Calculate logits by matrix multiplication between the final embedding and the transpose of the output weight tensor
logits = torch.matmul(final_embedding[-1], model["output.weight"].T)

# Print the shape of the resulting logits tensor
logits.shape

#### OUTPUT ####
torch.Size([128256])
#### OUTPUT ####

# Find the index of the maximum value along the last dimension to determine the next token
next_token = torch.argmax(logits, dim=-1)

# Output the index of the next token
next_token

#### OUTPUT ####
tensor(2983)
#### OUTPUT ####
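The last step, scoring every vocabulary entry against the final hidden state and taking the highest score, can be sketched on toy data in plain Python. The function `next_token_id` and all values below are hypothetical stand-ins for `final_embedding[-1]` and `model["output.weight"]`; the real tensors have 4096 columns and 128256 rows.

```python
def next_token_id(last_hidden, output_weight):
    """Score every vocabulary row against the last hidden state and pick the best."""
    logits = [sum(h * w for h, w in zip(last_hidden, row)) for row in output_weight]
    best_id = max(range(len(logits)), key=lambda i: logits[i])
    return best_id, logits

# Toy stand-ins (assumed values): hidden size 3, vocabulary of 4 tokens
last_hidden = [0.5, -1.0, 2.0]
output_weight = [[1.0, 0.0, 0.0],
                 [0.0, 1.0, 0.0],
                 [0.0, 0.0, 1.0],
                 [1.0, 1.0, 1.0]]
token_id, logits = next_token_id(last_hidden, output_weight)
print(token_id, logits)
```

Picking the maximum logit like this is greedy decoding. Only the hidden state of the last position is used, because that is the position whose next token we want to predict.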
Getting the Generated Text from the Token ID
# Decode the index of the next token using the tokenizer
tokenizer.decode([next_token.item()])

#### OUTPUT ####
42
#### OUTPUT ####
# input prompt
prompt = "Your Input"

# Replacing 17 number with total number of tokens in your input
# You can check total number of tokens using len(tokens)
freqs_for_each_token = torch.outer(torch.arange(17), freqs)