Fine-Tuning LLaMa 3 8B Instruct on Intel Max Series GPUs: An Exciting Journey
In this guide, we embark on an exciting journey to fine-tune the powerful LLaMa 3 8B Instruct model on a custom dataset using Intel Max Series GPUs. Intel GPUs can offer strong training throughput at a competitive cost, which makes them an attractive option for training large language models.
I successfully fine-tuned the LLaMa 3 8B Instruct model on my custom dataset using the Hugging Face ecosystem: I loaded my dataset from the Hub, downloaded the LLaMa 3 8B Instruct model, and trained it with the SFTTrainer from the TRL library, which builds on the Transformers Trainer. The results were outstanding.
Weights & Biases monitored the entire training run, providing detailed insights into memory usage, disk I/O, training loss, and more. It's truly an outstanding product!
The best part? It’s all free! I was amazed by the capabilities of the Intel Developer Cloud, particularly for machine learning (ML), high-performance computing (HPC), and generative AI (GenAI) workflows. The Intel® Data Center Max 1100 GPU, a high-performance 300-watt double-wide AIC card, features 56 Xe cores and 48 GB of HBM2E memory, delivering exceptional performance.
In the spirit of open source, I developed the following notebook, partially based on the original work of Rahul Unnikrishnan Nair from Intel, "Fine-tuning Google's Gemma Model on Intel Max Series GPUs." Thank you for the foundation. I adapted that work significantly into a process for fine-tuning LLaMa 3 models, which are a significant leap from LLaMa 2 in capability. It's amazing how such small models can produce great results when trained on quality data.
I will be producing a detailed video review of the notebook on my YouTube channel.
Note: This code was executed in a Jupyter notebook on Intel’s Developer Cloud.
Set PATH if you receive a path error while installing software
The cell below adds your .local/bin directory to PATH. Change the root directory to your own home directory, '/home//.local/bin' (your user ID goes between the two slashes). You can find your user ID by opening a new terminal and running $ pwd at the prompt, which prints your home directory.
import os
# Replace with your own home directory (keep the trailing slash)
user_dir = '/home/u2b3e96b2fc320ef8c781f51df67225d/'
# Add the directory to the PATH
os.environ['PATH'] += os.pathsep + user_dir + '.local/bin'
# Verify the PATH update
print("Updated PATH:", os.environ['PATH'])
# Check if the directory is now in PATH
if user_dir + '.local/bin' in os.environ['PATH']:
    print("Directory successfully added to PATH.")
else:
    print("Failed to add directory to PATH.")
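Hardcoding the home directory means every reader has to edit the path by hand. A minimal sketch, assuming your `.local/bin` lives under your home directory (the usual target of `pip install --user`), derives it with `os.path.expanduser` and avoids appending it twice if the cell is rerun. The function name `add_user_bin_to_path` is hypothetical:

```python
import os

def add_user_bin_to_path() -> str:
    """Append ~/.local/bin to PATH unless it is already one of its entries."""
    user_bin = os.path.expanduser("~/.local/bin")
    # Drop empty entries: an empty PATH element means "current directory"
    entries = [p for p in os.environ.get("PATH", "").split(os.pathsep) if p]
    if user_bin not in entries:
        entries.append(user_bin)
        os.environ["PATH"] = os.pathsep.join(entries)
    return user_bin

user_bin = add_user_bin_to_path()
print(user_bin in os.environ["PATH"].split(os.pathsep))  # True
```

Checking membership against the split entries, rather than with a substring test, avoids false positives such as a longer path that merely contains `.local/bin`.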
Run this only once to install the required versions of the base software.
!python -m pip install torch==2.1.0.post2 torchvision==0.16.0.post2 torchaudio==2.1.0.post2 intel-extension-for-pytorch==2.1.30.post0 oneccl_bind_pt==2.1.300+xpu --extra-index-url https://pytorch-extension.intel.com/release-whl/stable/xpu/us/
Step 1: Initial Setup
Before we begin fine-tuning, make sure the proper libraries are installed and the correct kernel is selected; this step needs to be performed only once. First, make sure you are using the Modin kernel on the Intel Developer Cloud. The Modin kernel is a specialized kernel designed for efficient data processing and analysis. To access it, join the Intel Developer Cloud by creating an account, navigate to the "Free Training" section once logged in, and select the Modin kernel in the Jupyter Lab environment that opens.
Run the following cell only once to install the required versions of the additional software.
import sys
import site
import os
# Install the required packages
# (PEP 440 only allows ".*" wildcards with == and !=, so plain minimum versions are used)
!{sys.executable} -m pip install --upgrade "transformers>=4.38"
!{sys.executable} -m pip install --upgrade "datasets>=2.18"
!{sys.executable} -m pip install --upgrade "wandb>=0.17"
!{sys.executable} -m pip install --upgrade "trl>=0.7.11"
!{sys.executable} -m pip install --upgrade "peft>=0.9.0"
!{sys.executable} -m pip install --upgrade "accelerate>=0.28"
!{sys.executable} -m pip install --upgrade "huggingface_hub"
# Get the site-packages directory
site_packages_dir = site.getsitepackages()[0]
# Add the site-packages directory where these packages are installed to the top of sys.path
if not os.access(site_packages_dir, os.W_OK):
    # No write access: pip installed into the user site-packages instead
    user_site_packages_dir = site.getusersitepackages()
    if user_site_packages_dir in sys.path:
        sys.path.remove(user_site_packages_dir)
    sys.path.insert(0, user_site_packages_dir)
else:
    if site_packages_dir in sys.path:
        sys.path.remove(site_packages_dir)
    sys.path.insert(0, site_packages_dir)
Optionally, check which versions of the base software are installed
# Import necessary libraries
import torch
import transformers
import wandb
import trl
import peft
import datasets
# Get versions of the libraries
torch_version = torch.__version__
transformers_version = transformers.__version__
wandb_version = wandb.__version__
trl_version = trl.__version__
peft_version = peft.__version__
datasets_version = datasets.__version__
# Print the versions
print(f"torch version: {torch_version}")
print(f"transformers version: {transformers_version}")
print(f"wandb version: {wandb_version}")
print(f"trl version: {trl_version}")
print(f"peft version: {peft_version}")
print(f"datasets version: {datasets_version}")
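Importing torch, wandb and the other libraries only to print their versions is slow. The standard library's `importlib.metadata` reads the installed distribution metadata without importing anything. A small sketch; the package list is an assumption, so adjust it to what you installed, and `installed_versions` is a hypothetical helper name:

```python
from importlib.metadata import PackageNotFoundError, version

def installed_versions(packages):
    """Return {package: version string, or None if not installed} without importing."""
    result = {}
    for name in packages:
        try:
            result[name] = version(name)
        except PackageNotFoundError:
            result[name] = None
    return result

for name, ver in installed_versions(
    ["torch", "transformers", "wandb", "trl", "peft", "datasets", "accelerate"]
).items():
    print(f"{name} version: {ver or 'not installed'}")
```

This reports what pip installed, which can differ from what an already-running kernel has imported, so restart the kernel after upgrading.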
Step 2: Check Intel XPU Availability and Retrieve Device Capabilities
In this step, we import the necessary libraries, check that an Intel XPU (the device type Intel Extension for PyTorch uses for Intel GPUs) is available, and retrieve detailed device capabilities. This confirms that the environment is configured correctly to get optimal performance from the XPU.
To optimize performance when using Intel Max Series GPUs:
- Retrieve CPU Information: Determine the number of physical CPU cores with psutil and calculate the cores per socket.
- Set Environment Variables:
  - Disable tokenizers parallelism.
  - Improve memory allocation with LD_PRELOAD (optional).
  - Reduce GPU command submission overhead.
  - Enable SDP fusion for efficient memory usage.
  - Configure OpenMP to use physical cores, bind threads, and set thread pinning.
- Print Configuration: Display the number of physical cores, cores per socket, and OpenMP environment variables to verify the settings.
import os
import json
import warnings
import psutil
import torch
import intel_extension_for_pytorch as ipex
warnings.filterwarnings("ignore")
# Check if Intel XPU is available
if torch.xpu.is_available():
    print("Intel XPU is available")
    for i in range(torch.xpu.device_count()):
        print(f"XPU Device {i}: {torch.xpu.get_device_name(i)}")
    # Get the device capability details
    device_capability = torch.xpu.get_device_capability()
    # Convert the device capability details to a JSON string with indentation for readability
    readable_device_capability = json.dumps(device_capability, indent=4)
    # Print the readable JSON
    print("Detail of GPU capability =\n", readable_device_capability)
else:
    print("Intel XPU is not available")
# Count physical cores; the per-socket figure assumes a two-socket host
num_physical_cores = psutil.cpu_count(logical=False)
num_cores_per_socket = num_physical_cores // 2
os.environ["TOKENIZERS_PARALLELISM"] = "false"
#HF_TOKEN = os.environ["HF_TOKEN"]
# Set the LD_PRELOAD environment variable
ld_preload = os.environ.get("LD_PRELOAD", "")
conda_prefix = os.environ.get("CONDA_PREFIX", "")
# Improve memory allocation performance; if tcmalloc is not available, please comment this line out.
# LD_PRELOAD is read when a process starts, so this only affects processes launched afterwards
os.environ["LD_PRELOAD"] = f"{ld_preload}:{conda_prefix}/lib/libtcmalloc.so"
# Reduce the overhead of submitting commands to the GPU
os.environ["SYCL_PI_LEVEL_ZERO_USE_IMMEDIATE_COMMANDLISTS"] = "1"
# Reduce memory accesses by fusing SDP (scaled dot-product attention) ops
os.environ["ENABLE_SDP_FUSION"] = "1"
# Set OpenMP threads to the number of physical cores
os.environ["OMP_NUM_THREADS"] = str(num_physical_cores)
# Set the thread affinity policy
os.environ["OMP_PROC_BIND"] = "close"
# Set the places for thread pinning
os.environ["OMP_PLACES"] = "cores"
print(f"Number of physical cores: {num_physical_cores}")
print(f"Number of cores per socket: {num_cores_per_socket}")
print("OpenMP environment variables:")
print(f" - OMP_NUM_THREADS: {os.environ['OMP_NUM_THREADS']}")
print(f" - OMP_PROC_BIND: {os.environ['OMP_PROC_BIND']}")
print(f" - OMP_PLACES: {os.environ['OMP_PLACES']}")
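The line `num_cores_per_socket = num_physical_cores // 2` assumes a two-socket machine. On Linux, a sketch that derives the socket count from the distinct `physical id` fields in `/proc/cpuinfo` instead. This assumes a Linux host, falls back to one socket when the file or the field is missing (common on some VMs), and the function names are hypothetical:

```python
def count_sockets(cpuinfo_text: str) -> int:
    """Count distinct 'physical id' values in /proc/cpuinfo text.

    Returns 1 when no 'physical id' lines are present.
    """
    ids = set()
    for line in cpuinfo_text.splitlines():
        key, sep, value = line.partition(":")
        if sep and key.strip() == "physical id":
            ids.add(value.strip())
    return max(1, len(ids))

def read_socket_count(path: str = "/proc/cpuinfo") -> int:
    try:
        with open(path) as f:
            return count_sockets(f.read())
    except OSError:
        return 1  # Non-Linux or restricted environment: assume a single socket

# Hypothetical replacement for `num_physical_cores // 2`:
# num_cores_per_socket = num_physical_cores // read_socket_count()
print(f"Detected sockets: {read_socket_count()}")
```

Separating the parser from the file read keeps the logic testable on any machine.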
Step 2b: Monitor XPU Memory Usage in Real-Time
The following script sets up a real-time monitor that continuously displays XPU memory usage in the Jupyter notebook, so you can track resource utilization during model training and inference. Watching these metrics lets you anticipate out-of-memory errors: if memory usage approaches the hardware limit, the batch size, sequence length, or model configuration may need to be adjusted.
- Memory Reserved: Indicates the total memory reserved by the XPU. Helps in understanding the memory footprint of the running processes.
- Memory Allocated: Shows the actual memory usage by tensors, crucial for identifying memory leaks or excessive usage.
- Max Memory Reserved/Allocated: These metrics help in identifying peak memory usage, which is essential for planning and scaling your models.
import psutil
import torch
import json
import asyncio
import threading
from IPython.display import display, HTML
import intel_extension_for_pytorch as ipex
def get_memory_usage():
    """Return reserved/allocated XPU memory (current and peak) in GB."""
    memory_reserved = round(torch.xpu.memory_reserved() / 1024**3, 3)
    memory_allocated = round(torch.xpu.memory_allocated() / 1024**3, 3)
    max_memory_reserved = round(torch.xpu.max_memory_reserved() / 1024**3, 3)
    max_memory_allocated = round(torch.xpu.max_memory_allocated() / 1024**3, 3)
    return memory_reserved, memory_allocated, max_memory_reserved, max_memory_allocated
def print_memory_usage():
    device_name = torch.xpu.get_device_name()
    print(f"XPU Name: {device_name}")
    memory_reserved, memory_allocated, max_memory_reserved, max_memory_allocated = get_memory_usage()
    memory_usage_text = f"XPU Memory: Reserved={memory_reserved} GB, Allocated={memory_allocated} GB, Max Reserved={max_memory_reserved} GB, Max Allocated={max_memory_allocated} GB"
    print(f"\r{memory_usage_text}", end="", flush=True)
async def display_memory_usage(output):
    device_name = torch.xpu.get_device_name()
    output.update(HTML(f"<p>XPU Name: {device_name}</p>"))
    while True:
        memory_reserved, memory_allocated, max_memory_reserved, max_memory_allocated = get_memory_usage()
        memory_usage_text = f"XPU ({device_name}) :: Memory: Reserved={memory_reserved} GB, Allocated={memory_allocated} GB, Max Reserved={max_memory_reserved} GB, Max Allocated={max_memory_allocated} GB"
        output.update(HTML(f"<p>{memory_usage_text}</p>"))
        await asyncio.sleep(5)  # Refresh every 5 seconds
def start_memory_monitor(output):
    # Run a dedicated event loop in a background thread and make it the current loop
    # only in that thread; daemon=True lets the kernel shut down without joining it
    loop = asyncio.new_event_loop()
    def run_loop():
        asyncio.set_event_loop(loop)
        loop.run_until_complete(display_memory_usage(output))
    thread = threading.Thread(target=run_loop, daemon=True)
    thread.start()
if torch.xpu.is_available():
    torch.xpu.empty_cache()
    output = display(display_id=True)
    start_memory_monitor(output)
else:
    print("XPU device not available.")
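To turn "memory usage approaches the hardware limits" into a concrete check, a small hypothetical helper can compare a memory reading against the card's capacity and flag it above a threshold. The 90% threshold is an arbitrary assumption, and the commented XPU lines assume IPEX exposes `total_memory` on the device properties:

```python
def memory_headroom(used_bytes: int, total_bytes: int, warn_fraction: float = 0.9):
    """Return (fraction_used, near_limit) for a device memory snapshot."""
    if total_bytes <= 0:
        raise ValueError("total_bytes must be positive")
    fraction = used_bytes / total_bytes
    return fraction, fraction >= warn_fraction

# Illustrative numbers for a 48 GB card
used, near_limit = memory_headroom(44 * 1024**3, 48 * 1024**3)
print(f"{used:.1%} used, near limit: {near_limit}")  # 91.7% used, near limit: True

# On the XPU itself (assumption: total_memory is available on the device properties):
# import torch, intel_extension_for_pytorch
# total = torch.xpu.get_device_properties(0).total_memory
# print(memory_headroom(torch.xpu.memory_reserved(), total))
```

Reserved memory is the better input than allocated memory here, because the caching allocator holds reserved blocks even when tensors have been freed.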
Step 3: Load and Prepare the Model
In this step, we load the model and tokenizer, place them on the appropriate device (an Intel XPU or the CPU), and configure them for efficient fine-tuning.
Check Device Availability:
- Check whether an XPU is available and set the device accordingly: if the XPU is available and USE_CPU is not set to True, use the XPU; otherwise, use the CPU.
Specify Model Name:
- Define the model name to be used.
Download Model if Not Existing Locally:
- from_pretrained downloads the model and tokenizer from the Hugging Face Hub on first use and caches them locally; later runs load them from the cache.
Load Model and Tokenizer:
- Load the model and tokenizer.
- Set the padding token and padding side for the tokenizer.
- Resize the model's embeddings to account for any new special tokens added.
- Set the padding token ID in the model's generation configuration.
Move Model to Device:
- Move the model to the appropriate device (XPU or CPU).
Configure Model for Fine-Tuning:
- Disable the caching mechanism to reduce memory usage during fine-tuning.
- Configure the model's pre-training tensor parallelism degree (pretraining_tp = 1).
import os
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
# Check if XPU is available and set the device accordingly
USE_CPU = False
device = "xpu:0" if torch.xpu.is_available() and not USE_CPU else "cpu"
print(f"Using device: {device}")
# Specify the model name
model_id = "meta-llama/Meta-Llama-3-8B-Instruct"
# Load the model and tokenizer
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id)
# Check if the tokenizer has a padding token
if tokenizer.pad_token is None:
    print("Adding padding token to tokenizer.")
    tokenizer.add_special_tokens({'pad_token': '[PAD]'})
    model.resize_token_embeddings(len(tokenizer))
else:
    print(f"Padding token already exists: {tokenizer.pad_token}")
# Set the padding token and padding side
tokenizer.pad_token = tokenizer.pad_token or tokenizer.eos_token
tokenizer.padding_side = "right"
# Resize the model embeddings to account for the new special tokens (no change if already resized above)
model.resize_token_embeddings(len(tokenizer))
# Debugging statements
print(f"Padding token: {tokenizer.pad_token}")
print(f"Padding token ID: {tokenizer.pad_token_id}")
# Set the padding token ID for the generation configuration
model.generation_config.pad_token_id = tokenizer.pad_token_id
# Move the model to the appropriate device
model.to(device)
# Disable caching mechanism to reduce memory usage during fine-tuning
model.config.use_cache = False
# Configure the model's pre-training tensor parallelism degree
model.config.pretraining_tp = 1
print("Model and tokenizer are ready for use.")
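Step 4: Log into your Hugging Face account with your access token.
Uncheck "Add token as git credential"! 🎛️ Because meta-llama/Meta-Llama-3-8B-Instruct is a gated model, you also need to accept Meta's license on its model page, and you must be logged in before the download in Step 3 can succeed; if that download failed, run this cell and then rerun Step 3.
# Log in to Hugging Face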
from huggingface_hub import notebook_login
notebook_login()
Step 5: Load and Inspect the Dataset 📊
Import the load_dataset function and load the specified dataset from the Hugging Face datasets library. In this case, the dataset identifier is RayBernard/nvidia-dgx-best-practices, and we are loading the training split of the dataset. Print the first instruction and response from the dataset to ensure the content is as expected. Next, print the total number of examples in the dataset to understand its size. List the fields (keys) present in the dataset to understand its structure.
Format and Split the Dataset for Training
This step formats the dataset and splits it for training, making it ready for fine-tuning.
Load and Define:
- Load the dataset with the specified name and split. Here, we are loading the "train" split of the dataset.
- Define the system message to be used for formatting prompts.
Format Prompts:
- Use the format_prompts function to format the dataset prompts according to the Meta Llama 3 Instruct prompt template with special tokens. This function iterates over the 'instruction' and 'output' fields in the batch and formats them accordingly.
- Apply the format_prompts function to the dataset in a batched manner for efficiency.
Split the Dataset:
- Split the formatted dataset into training and validation sets, using 20% of the data for validation and setting a seed for reproducibility.
Verify the Split:
- Print the number of examples in both the training and validation sets to verify the split.
Show Formatted Prompt:
- Define and use a function to show the formatted prompt for the first record, demonstrating what the prompt looks like with the system message included.
This process ensures that your dataset is well organized and ready for the training phase.
from datasets import load_dataset
# Load a specific dataset from the Hugging Face datasets library.
# 'RayBernard/nvidia-dgx-best-practices' is the identifier for the dataset,
# and 'split="train"' specifies that we want the training split of the dataset.
dataset_name = "RayBernard/nvidia-dgx-best-practices"
dataset = load_dataset(dataset_name, split="train")
# Print the first instruction and response from the dataset to verify the content.
print(f"Instruction is: {dataset[0]['instruction']}")
print(f"Response is: {dataset[0]['output']}")
# Print the number of examples in the dataset.
print(f"Number of examples in the dataset: {len(dataset)}")
# Print the fields (keys) present in the dataset.
print(f"Fields in the dataset: {list(dataset.features.keys())}")
# Print the entire dataset to get an overview of its structure and contents.
print(dataset)
# Load the dataset with the specified name and split
# Here, we are loading the "train" split of the dataset
dataset = load_dataset(dataset_name, split="train")
# Define the system message separately
# system_message = "Respond with the appropriate command only"
system_message = "You are a helpful AI "
def format_prompts(batch, system_msg):
    """
    Format the prompts according to the Meta Llama 3 Instruct prompt template with special tokens.
    Args:
        batch (dict): A batch of data containing 'instruction' and 'output' fields.
        system_msg (str): The system message to be included in the prompt.
    Returns:
        dict: A dictionary containing the formatted prompts under the 'text' key.
    """
    # Initialize an empty list to store the formatted prompts
    formatted_prompts = []
    # Iterate over the 'instruction' and 'output' fields in the batch
    for instruction, output in zip(batch["instruction"], batch["output"]):
        # Llama 3 Instruct layout: <|start_header_id|>role<|end_header_id|>, a blank line,
        # the content, then <|eot_id|> to close the turn. <|begin_of_text|> is left out
        # because the Llama 3 tokenizer prepends it when the text is tokenized.
        prompt = (
            "<|start_header_id|>system<|end_header_id|>\n\n"
            f"{system_msg}<|eot_id|>"
            "<|start_header_id|>user<|end_header_id|>\n\n"
            f"{instruction}<|eot_id|>"
            "<|start_header_id|>assistant<|end_header_id|>\n\n"
            f"{output}<|eot_id|>"
        )
        # Append the formatted prompt to the list
        formatted_prompts.append(prompt)
    # Return the formatted prompts as a dictionary with the key 'text'
    return {"text": formatted_prompts}
# Apply the format_prompts function to the dataset
# The function is applied in a batched manner to speed up processing
formatted_dataset = dataset.map(lambda batch: format_prompts(batch, system_message), batched=True)
# Split the dataset into training and validation sets
# 20% of the data is used for validation, and a seed is set for reproducibility
split_dataset = formatted_dataset.train_test_split(test_size=0.2, seed=99)
train_dataset = split_dataset["train"]
validation_dataset = split_dataset["test"]
print("train dataset == ",train_dataset)
print("validation dataset ==", validation_dataset)
# Print the number of examples in the training and validation sets
print(f"Number of examples in the training set: {len(train_dataset)}")
print(f"Number of examples in the validation set: {len(validation_dataset)}")
# Function to show what the prompt looks like for the first record with the system message
def show_first_prompt(system_msg):
    # Get the first record from the dataset
    first_instruction = dataset[0]["instruction"]
    first_output = dataset[0]["output"]
    # Format the first record with the same function used for training,
    # so the preview always matches the training data
    prompt = format_prompts(
        {"instruction": [first_instruction], "output": [first_output]}, system_msg
    )["text"][0]
    # Print the original instruction and output
    print(f"Original instruction: {first_instruction}")
    print(f"Original output: {first_output}")
    # Print the formatted prompt
    print(f"\nFormatted prompt with system message:\n{prompt}")
# Show what the prompt looks like for the first record with the system message
show_first_prompt(system_message)
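A prompt built with the wrong template (for example one using `<|startoftext|>` markers, which are not Llama 3 special tokens) still trains without any error, so a mismatch can go unnoticed until inference. A quick heuristic check on one formatted example can catch this early. The helper below is hypothetical and only checks the layout documented for Llama 3 Instruct: system, user and assistant headers in order, with each turn closed by `<|eot_id|>`:

```python
LLAMA3_HEADERS = (
    "<|start_header_id|>system<|end_header_id|>",
    "<|start_header_id|>user<|end_header_id|>",
    "<|start_header_id|>assistant<|end_header_id|>",
)

def looks_like_llama3_chat(text: str) -> bool:
    """Heuristic: headers appear in order and the text ends with a closed turn."""
    position = 0
    for header in LLAMA3_HEADERS:
        idx = text.find(header, position)
        if idx == -1:
            return False
        position = idx + len(header)
    return text.count("<|eot_id|>") >= 3 and text.rstrip().endswith("<|eot_id|>")

good = (
    "<|start_header_id|>system<|end_header_id|>\n\nYou are a helpful AI<|eot_id|>"
    "<|start_header_id|>user<|end_header_id|>\n\nExample question<|eot_id|>"
    "<|start_header_id|>assistant<|end_header_id|>\n\nExample answer<|eot_id|>"
)
print(looks_like_llama3_chat(good))  # True
print(looks_like_llama3_chat("<|startoftext|>system\nhi\n<|endoftext|>"))  # False
```

In the notebook you could run it as `looks_like_llama3_chat(train_dataset[0]["text"])` before starting training.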
Step 6: Fine-Tune the Model and Save the Results
- Set Up Imports and Configurations:
In this step, we configure LoRA (Low-Rank Adaptation) for efficient training. LoRA freezes the base model's weights and trains small low-rank update matrices instead, which sharply reduces the number of trainable parameters. Here, we instantiate a LoraConfig object with parameters tailored to our training needs.
Instantiate LoRA Configuration:
- r: Set to 64, this controls the rank of the low-rank update matrices, balancing adapter capacity and efficiency.
- lora_alpha: Set to 32, this scaling factor (applied as lora_alpha / r) adjusts the strength of the low-rank update.
- lora_dropout: Set to 0.3, this dropout rate regularizes the LoRA layers to prevent overfitting. A higher value increases regularization.
- bias: Set to "all", which makes all bias parameters trainable.
- target_modules: Specifies the layers where the low-rank adaptation is applied: "q_proj", "v_proj", and "o_proj" (the attention query, value, and output projections).
- task_type: Set to "CAUSAL_LM", indicating that this configuration is for a causal language modeling task.
- Together, these settings balance training efficiency and model quality.
Set Environment Variables:
- Configure relevant environment variables for logging and configuration, including Weights and Biases project settings.
Load Datasets:
- Load the training and validation datasets.
Configure Training Parameters:
- Set training parameters including batch size, gradient accumulation steps, learning rate, and mixed precision training.
Initialize Trainer:
- Initialize the SFTTrainer with the LoRA configuration, training arguments, and datasets.
Optimize Performance:
- Clear the XPU cache before starting the training process.
Begin Training:
- Start the training process.
- Print a summary of the training results, including total training time and samples processed per second.
- Handle any exceptions to ensure smooth execution.
Save the Model:
- Save the fine-tuned LoRA model to the specified path for future use.
This step-by-step approach ensures that the model is properly fine-tuned and ready for deployment, with optimized performance settings and comprehensive logging in Weights & Biases.
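To see why LoRA is cheap, count its trainable parameters. For each targeted linear layer with weight shape in × out, LoRA trains a matrix A (r × in) and a matrix B (out × r) while the base weight stays frozen, and the update B·A is scaled by lora_alpha / r. A back-of-the-envelope sketch, assuming the published Llama 3 8B shapes (hidden size 4096, 32 layers, 8 key/value heads of dimension 128) and adapters on q_proj, v_proj and o_proj; the function name is hypothetical:

```python
HIDDEN = 4096          # Llama 3 8B hidden size
KV_DIM = 8 * 128       # 8 key/value heads x head dim 128 (grouped-query attention)
NUM_LAYERS = 32

# (in_features, out_features) of each targeted projection
TARGET_SHAPES = {
    "q_proj": (HIDDEN, HIDDEN),
    "v_proj": (HIDDEN, KV_DIM),
    "o_proj": (HIDDEN, HIDDEN),
}

def lora_trainable_params(r: int, shapes: dict, num_layers: int) -> int:
    """A is r x in and B is out x r, so each adapted layer adds r * (in + out) params."""
    per_layer = sum(r * (fan_in + fan_out) for fan_in, fan_out in shapes.values())
    return per_layer * num_layers

n = lora_trainable_params(r=64, shapes=TARGET_SHAPES, num_layers=NUM_LAYERS)
print(f"LoRA trainable parameters: {n:,}")  # LoRA trainable parameters: 44,040,192
print(f"Fraction of ~8.0B base weights: {n / 8.0e9:.2%}")
```

Doubling r doubles this count, while lora_alpha changes only the scaling of the update and adds no parameters.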
import os
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
from peft import PeftModel
from trl import SFTConfig, SFTTrainer
from peft import LoraConfig
import wandb
# Configuration variables
PUSH_TO_HUB = True # Flag to determine if the model should be pushed to Hugging Face Hub
USE_WANDB = True # Flag to determine if Weights and Biases (WandB) should be used for logging
# Unset the LD_PRELOAD environment variable if it exists
os.environ.pop('LD_PRELOAD', None)
# LoRA (Low-Rank Adaptation) configuration for model fine-tuning
lora_config = LoraConfig(
    r=64,  # Rank of the low-rank update matrices; 32-64 is a common choice
           # Increase (e.g., 128) for more adapter capacity at the cost of more trainable parameters and memory
    lora_alpha=32,  # Scaling factor; the LoRA update is scaled by lora_alpha / r
                    # Increase (e.g., 64) to strengthen the update; this does not change memory use
    lora_dropout=0.3,  # Dropout probability for LoRA layers, typical range is 0.1-0.3
                       # Decrease (e.g., 0.1) for less regularization, increase (e.g., 0.5) for more
    bias="all",  # Make all bias parameters trainable
    target_modules=["q_proj", "v_proj", "o_proj"],  # Attention projections to adapt; Llama names its output projection "o_proj"
    task_type="CAUSAL_LM"  # Causal language modeling: the model predicts the next token in a sequence
)
# ID of the fine-tuned model to be pushed to Hugging Face Hub
finetuned_model_id = "RayBernard/llama-3-8B-Instruct-ft"
# Set TOKENIZERS_PARALLELISM environment variable to avoid parallelism warning during tokenization
os.environ["TOKENIZERS_PARALLELISM"] = "false"
# Set other environment variables for WandB logging and configuration
os.environ['WANDB_NOTEBOOK_NAME'] = 'llama3-8B-FT-Intel-XPU.0.0.1.ipynb' # Name of the notebook for WandB logging
os.environ["WANDB_PROJECT"] = "llama3-8b-Instruct-ft" # WandB project name
os.environ["WANDB_LOG_MODEL"] = "checkpoint" # WandB log model checkpoint
os.environ["IPEX_TILE_AS_DEVICE"] = "1" # Intel XPU configuration
# Training configuration
num_train_samples = len(train_dataset) # Number of training samples in the dataset
batch_size = 2 # Number of training examples utilized in one forward/backward pass, typical range is 2-16
# Increase (e.g., 4) for better memory utilization if sufficient GPU memory is available
gradient_accumulation_steps = 16 # Number of steps to accumulate gradients before performing an optimizer step
# Decrease (e.g., 8) to use more GPU memory per step, increase for less memory use
steps_per_epoch = max(1, num_train_samples // (batch_size * gradient_accumulation_steps))  # Optimizer steps per epoch (at least 1 for small datasets)
num_epochs = 25 # Total number of passes through the entire training dataset, typical range is 3-30
max_steps = steps_per_epoch * num_epochs # Total number of training steps
print(f"Finetuning for max number of steps: {max_steps}") # Print the total number of training steps
def print_training_summary(results):
"""
Print a summary of the training results.
Args:
results: Training results object containing metrics.
"""
print(f"Time: {results.metrics['train_runtime']: .2f}")
print(f"Samples/second: {results.metrics['train_samples_per_second']: .2f}")
# Configuration for the supervised fine-tuning (SFT) trainer
training_args = SFTConfig(
run_name="llama3-8b-finetuning2", # Unique name for this training run
per_device_train_batch_size=batch_size, # Batch size per device during training
gradient_accumulation_steps=gradient_accumulation_steps, # Steps to accumulate gradients before updating model parameters
warmup_ratio=0.1, # Ratio of total steps used for a linear warm-up of the learning rate, typical range is 0.06-0.2
max_steps=max_steps, # Maximum number of training steps to perform
learning_rate=2e-5, # Learning rate for the optimizer, typical range is 1e-5 to 5e-5
# Decrease (e.g., 1e-5) for better precision, increase (e.g., 3e-5) for faster convergence
lr_scheduler_type="cosine", # Learning rate scheduler type: cosine annealing
evaluation_strategy="steps", # Evaluation strategy: evaluate the model every few steps
save_steps=500, # Save a checkpoint of the model every 500 steps
fp16=True, # Use 16-bit (half precision) floating point arithmetic to speed up training and reduce memory usage
logging_steps=100, # Log training metrics every 100 steps
output_dir=finetuned_model_id, # Directory where the model checkpoints will be saved
hub_model_id=finetuned_model_id if PUSH_TO_HUB else None, # Model ID for pushing to Hugging Face Hub
report_to="wandb" if USE_WANDB else None, # Report metrics to WandB if enabled
push_to_hub=PUSH_TO_HUB, # Flag to push the model to Hugging Face Hub
max_grad_norm=0.6, # Maximum gradient norm for gradient clipping, typical range is 0.1-1.0
weight_decay=0.01, # Weight decay coefficient for regularization, typical range is 0.01-0.1
# Increase (e.g., 0.1) for more regularization to prevent overfitting, decrease for less
group_by_length=True, # Group sequences of similar length for efficient training
gradient_checkpointing=True # Enable gradient checkpointing to save memory at the cost of some compute overhead
)
# Initialize the SFT trainer with the model, training arguments, datasets, and tokenizer
trainer = SFTTrainer(
model=model, # Model to be fine-tuned
args=training_args, # Training arguments
train_dataset=train_dataset, # Training dataset
eval_dataset=validation_dataset, # Validation dataset
tokenizer=tokenizer, # Tokenizer
peft_config=lora_config, # LoRA configuration
dataset_text_field="text", # Text field in the dataset
max_seq_length=512, # Maximum sequence length for tokenization, typical values are 512-1024
# Increase (e.g., 1024) for handling longer contexts but higher memory use
packing=True # Enable packing of sequences to make use of available space in a batch more efficiently
)
# Try to train the model and handle any exceptions
try:
    torch.xpu.empty_cache()  # Clear the cache of the Intel XPU
    results = trainer.train()  # Train the model
    print_training_summary(results)  # Print a summary of the training results
    wandb.finish()  # Finish the WandB run
except Exception as e:
    print(f"Error during training: {e}")  # Print any errors that occur during training
# Save the fine-tuned model
tuned_lora_model = "llama3-8b-Instruct-ft-lora" # Directory name for saving the model
trainer.model.save_pretrained(tuned_lora_model) # Save the model to the specified directory
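The step arithmetic in the training cell is easy to misjudge, because each optimizer step consumes batch_size × gradient_accumulation_steps examples. A small sketch that mirrors it, including the zero-step guard, and also derives the warmup length: when warmup_steps is not set, the Transformers Trainer computes warmup from warmup_ratio as ceil(max_steps × warmup_ratio). The function name and the dataset size below are illustrative assumptions:

```python
import math

def training_schedule(num_samples, batch_size, grad_accum, num_epochs, warmup_ratio):
    """Return (effective_batch, steps_per_epoch, max_steps, warmup_steps)."""
    effective_batch = batch_size * grad_accum
    steps_per_epoch = max(1, num_samples // effective_batch)
    max_steps = steps_per_epoch * num_epochs
    # Warmup derived from the ratio, rounded up
    warmup_steps = math.ceil(max_steps * warmup_ratio)
    return effective_batch, steps_per_epoch, max_steps, warmup_steps

# Illustrative dataset size; substitute len(train_dataset)
print(training_schedule(1000, batch_size=2, grad_accum=16, num_epochs=25, warmup_ratio=0.1))
# (32, 31, 775, 78)
```

With packing=True, SFTTrainer concatenates examples into max_seq_length chunks, so the number of sequences seen per epoch differs from the raw sample count and these figures are only approximate.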
Step 7: Merge and Save the Fine-Tuned Model
After fine-tuning, merge the fine-tuned LoRA adapter with the base model and save the final tuned model. This folds the fine-tuning adjustments into the base model's weights, producing a single, ready-to-use model.
- Import Required Libraries: Import the necessary libraries from peft and transformers.
- Load Base Model: Load the base model using AutoModelForCausalLM with the specified model ID and settings that reduce memory usage.
- Merge Models: Use PeftModel to load the fine-tuned LoRA model and merge it with the base model.
- Unload Unnecessary Parameters: Unload the LoRA modules, which are no longer needed once their weights are merged.
- Save the Final Model: Save the final merged model to the specified path for future use.
This step finalizes the training process by producing a single fine-tuned model ready for inference.
# Import the necessary libraries
from peft import PeftModel
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch
# Define the path to the fine-tuned LoRA model and the base model ID
tuned_lora_model = "llama3-8b-Instruct-ft-lora"
model_id = "meta-llama/Meta-Llama-3-8B-Instruct"
# Load the base model using the specified model ID.
# The parameters used are:
# - low_cpu_mem_usage: Reduces memory usage on the CPU.
# - return_dict: Ensures the model returns outputs as a dictionary.
# - torch_dtype: Specifies the data type as bfloat16 for efficient computation.
base_model = AutoModelForCausalLM.from_pretrained(
model_id,
low_cpu_mem_usage=True,
return_dict=True,
torch_dtype=torch.bfloat16,
)
# Load the tokenizer using the same model ID
tokenizer = AutoTokenizer.from_pretrained(model_id)
# Check if the tokenizer has a padding token
if tokenizer.pad_token is None:
    print("Adding padding token to tokenizer.")
    # Add a special padding token to the tokenizer if it doesn't have one
    tokenizer.add_special_tokens({'pad_token': '[PAD]'})
    # Resize the token embeddings of the base model to match the updated tokenizer
    base_model.resize_token_embeddings(len(tokenizer))
else:
    print(f"Padding token already exists: {tokenizer.pad_token}")
# Ensure the padding token is set (use the existing one or the EOS token if none exists)
tokenizer.pad_token = tokenizer.pad_token or tokenizer.eos_token
# Set the padding side to 'right', meaning that padding tokens will be added to the end of sequences
tokenizer.padding_side = "right"
# Load the PEFT (Parameter-Efficient Fine-Tuning) model with the pre-trained base model and the LoRA-tuned model
model = PeftModel.from_pretrained(base_model, tuned_lora_model)
# Merge the LoRA parameters into the base model and unload the PEFT structure
model = model.merge_and_unload()
# Define the path where the final tuned model and tokenizer will be saved
final_model_path = "final-tuned-model" # Replace with your desired path
# Save the final tuned model to the specified path
model.save_pretrained(final_model_path)
# Save the tokenizer to the specified path
tokenizer.save_pretrained(final_model_path)
print("Final tuned model and tokenizer saved successfully.")
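For intuition about what `merge_and_unload()` does, here is a minimal pure-Python sketch of folding a standard LoRA update into a frozen weight matrix: W_merged = W + (lora_alpha / r) · B · A. The toy matrices and the helpers `matmul`, `add`, and `merge_lora` are hypothetical and exist only for illustration; PEFT applies the equivalent operation to real tensors in every adapted layer, and its scaling differs in some configurations (for example, when rsLoRA is enabled). The sketch checks that the merged layer produces the same output as the base layer plus the separate adapter path.

```python
from typing import List

Matrix = List[List[float]]


def matmul(a: Matrix, b: Matrix) -> Matrix:
    """Naive product of a (n x k) and b (k x m)."""
    return [
        [sum(a[i][t] * b[t][j] for t in range(len(b))) for j in range(len(b[0]))]
        for i in range(len(a))
    ]


def add(a: Matrix, b: Matrix, scale: float = 1.0) -> Matrix:
    """Element-wise a + scale * b."""
    return [[x + scale * y for x, y in zip(row_a, row_b)] for row_a, row_b in zip(a, b)]


def merge_lora(w: Matrix, lora_a: Matrix, lora_b: Matrix, alpha: float, r: int) -> Matrix:
    """Fold the low-rank update B @ A, scaled by alpha / r, into the base weight W."""
    return add(w, matmul(lora_b, lora_a), scale=alpha / r)


# Hypothetical toy layer: W is 2x3 (out x in), A is r x in, B is out x r, with r = 1.
W = [[1.0, 0.0, 2.0],
     [0.0, 1.0, -1.0]]
A = [[1.0, 2.0, 0.0]]
B = [[0.5],
     [-1.0]]
alpha, r = 2.0, 1

W_merged = merge_lora(W, A, B, alpha, r)

# Input column vector x (3 x 1).
x = [[1.0], [1.0], [1.0]]

# Unmerged forward pass: W x + (alpha / r) * B (A x)
unmerged = add(matmul(W, x), matmul(B, matmul(A, x)), scale=alpha / r)
# Merged forward pass: W_merged x
merged = matmul(W_merged, x)

print(W_merged)            # [[2.0, 2.0, 2.0], [-2.0, -3.0, -1.0]]
print(merged == unmerged)  # True
```

Because the update is folded into W once, the merged model runs at the same cost as the base model and no longer needs PEFT at inference time.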
Optional: Upload your model to the Hugging Face Hub
from huggingface_hub import HfApi, upload_folder, login
import os
final_model_path = "final-tuned-model" # Path where the model and tokenizer are saved
repo_name = "llama3-8b-Instruct-finetuned" # Name of the repository on Hugging Face Hub
username = "RayBernard"  # Your username
repo_id = f"{username}/{repo_name}"  # Full repository ID (namespace/name)
# Log in to your Hugging Face account if you have not done so already
# login()
# Retrieve the token saved in the cache by a previous login
with open(os.path.expanduser("~/.cache/huggingface/token"), "r") as token_file:
    token = token_file.read().strip()
# Create the repository, or reuse it if it already exists
api = HfApi()
api.create_repo(repo_id=repo_id, token=token, exist_ok=True)
# Upload the entire model directory to the same repository
upload_folder(
    folder_path=final_model_path,
    repo_id=repo_id,
    token=token,
    repo_type="model",
)
print(f"Model and tokenizer uploaded to Hugging Face Hub repository: {repo_id}")
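The upload cell reads the token directly from `~/.cache/huggingface/token`, a file that only exists after a previous login. A slightly more forgiving approach is to check the `HF_TOKEN` environment variable first (which `huggingface_hub` also honors) and fall back to the cached file, and to build the namespaced repository ID once so that repository creation and upload always target the same repo. The helpers `resolve_hf_token` and `full_repo_id` are hypothetical, and the sketch ignores other settings such as a custom `HF_HOME`.

```python
import os
import tempfile
from pathlib import Path
from typing import Mapping, Optional


def resolve_hf_token(
    env: Optional[Mapping[str, str]] = None,
    token_file: Optional[Path] = None,
) -> Optional[str]:
    """Hypothetical helper: return HF_TOKEN if set, else the cached login token, else None."""
    env = os.environ if env is None else env
    env_token = env.get("HF_TOKEN", "").strip()
    if env_token:
        return env_token
    # Default cache location written by a previous login (ignores a custom HF_HOME).
    path = token_file if token_file is not None else Path.home() / ".cache" / "huggingface" / "token"
    if path.is_file():
        return path.read_text().strip() or None
    return None


def full_repo_id(username: str, repo_name: str) -> str:
    """Hypothetical helper: build one namespaced ID for both repo creation and upload."""
    if "/" in repo_name:
        raise ValueError("repo_name must not already include a namespace")
    return f"{username}/{repo_name}"


# Demonstration with a temporary token file; the real environment is not modified.
with tempfile.TemporaryDirectory() as tmp:
    demo_file = Path(tmp) / "token"
    demo_file.write_text("hf_example_from_file\n")
    print(resolve_hf_token(env={}, token_file=demo_file))
    print(resolve_hf_token(env={"HF_TOKEN": "hf_example_from_env"}, token_file=demo_file))

print(full_repo_id("RayBernard", "llama3-8b-Instruct-finetuned"))
```

In the upload cell you would then set `token = resolve_hf_token()` and `repo_id = full_repo_id(username, repo_name)`, and pass that same `repo_id` to both `api.create_repo` and `upload_folder`.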
Test the model without fine-tuning. Note: at this point, restart the kernel and clear all outputs so that resources are freed up.
# Import necessary libraries
import transformers
import torch
# Define the model ID for the base model
model_id = "meta-llama/Meta-Llama-3-8B-Instruct"
# Create a text-generation pipeline using the specified model ID
pipeline = transformers.pipeline(
    "text-generation",  # Task type is text generation
    model=model_id,  # Use the specified model
    model_kwargs={"torch_dtype": torch.bfloat16},  # Model parameters: use bfloat16 for efficient computation
    device_map="auto",  # Automatically map model to available devices (e.g., GPU if available)
)
# Define a list of messages for the text-generation pipeline
messages_list = [
    [
        {"role": "system", "content": "Just respond with the command"},  # System message
        {"role": "user", "content": "how many gpu are in an h100?"},  # User query
    ],
    [
        {"role": "system", "content": "Just respond with the command"},  # System message
        {"role": "user", "content": "how many gpu are in an h200?"},  # User query
    ],
    [
        {"role": "system", "content": "Just respond with the command"},  # System message
        {"role": "user", "content": "what kind of switch os to run InfiniBand network"},  # User query
    ],
]
# Define terminators for stopping the generation
terminators = [
    pipeline.tokenizer.eos_token_id,  # End-of-sequence token ID
    pipeline.tokenizer.convert_tokens_to_ids("<|eot_id|>"),  # End-of-turn token used by the Llama 3 chat format
]
# Loop through the list of messages
for messages in messages_list:
    outputs = pipeline(
        messages,  # Input messages to the pipeline
        max_new_tokens=256,  # Generate up to 256 new tokens
        eos_token_id=terminators,  # Stop when any of the terminator tokens is generated
        do_sample=True,  # Enable sampling for text generation
        temperature=0.6,  # Sampling temperature: lower values make output more deterministic
        top_p=0.9,  # Top-p (nucleus) sampling: sample from the smallest token set covering 90% of the probability mass
    )
    # The last entry of generated_text is the assistant's reply message
    print(outputs[0]["generated_text"][-1]["content"])
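When you pass a list of chat messages, the pipeline formats them with the tokenizer's chat template before generating. The pure-Python sketch below approximates the Llama 3 Instruct prompt format to show why the end-of-turn token `<|eot_id|>` is a natural stop token: every turn ends with it, so the assistant's reply is complete once the model emits it. The function `render_llama3_prompt` is hypothetical; for real inputs, use `tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)`, which applies the exact template shipped with the model.

```python
from typing import Dict, List

BOS = "<|begin_of_text|>"
EOT = "<|eot_id|>"


def render_llama3_prompt(messages: List[Dict[str, str]], add_generation_prompt: bool = True) -> str:
    """Approximate the Llama 3 Instruct chat format as a plain string."""
    parts = [BOS]
    for message in messages:
        # Each turn: role header, blank line, trimmed content, end-of-turn marker.
        parts.append(
            f"<|start_header_id|>{message['role']}<|end_header_id|>\n\n"
            f"{message['content'].strip()}{EOT}"
        )
    if add_generation_prompt:
        # Open an assistant turn so the model continues with its reply.
        parts.append("<|start_header_id|>assistant<|end_header_id|>\n\n")
    return "".join(parts)


messages = [
    {"role": "system", "content": "Just respond with the command"},
    {"role": "user", "content": "how many gpu are in an h100?"},
]
prompt = render_llama3_prompt(messages)
print(prompt)
```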
Now test the model after fine-tuning
import transformers
import torch
# Path to the merged fine-tuned model saved earlier (a local directory)
model_id = "final-tuned-model"
# Create a text-generation pipeline using the fine-tuned model ID
pipeline = transformers.pipeline(
    "text-generation",  # Task type is text generation
    model=model_id,  # Use the fine-tuned model
    model_kwargs={"torch_dtype": torch.bfloat16},  # Model parameters: use bfloat16 for efficient computation
    device_map="auto",  # Automatically map model to available devices (e.g., GPU if available)
)
# Define the same list of messages for testing the fine-tuned model
messages_list = [
    [
        {"role": "system", "content": "Just respond with the command"},  # System message
        {"role": "user", "content": "how many gpu are in an h100?"},  # User query
    ],
    [
        {"role": "system", "content": "Just respond with the command"},  # System message
        {"role": "user", "content": "how many gpu are in an h200?"},  # User query
    ],
    [
        {"role": "system", "content": "Just respond with the command"},  # System message
        {"role": "user", "content": "what kind of switch os to run InfiniBand network"},  # User query
    ],
]
# Define terminators for stopping the generation
terminators = [
    pipeline.tokenizer.eos_token_id,  # End-of-sequence token ID
    pipeline.tokenizer.convert_tokens_to_ids("<|eot_id|>"),  # End-of-turn token used by the Llama 3 chat format
]
# Loop through the list of messages
for messages in messages_list:
    outputs = pipeline(
        messages,  # Input messages to the pipeline
        max_new_tokens=256,  # Generate up to 256 new tokens
        eos_token_id=terminators,  # Stop when any of the terminator tokens is generated
        do_sample=True,  # Enable sampling for text generation
        temperature=0.1,  # Sampling temperature: lower values make output more deterministic
        top_p=0.9,  # Top-p (nucleus) sampling: sample from the smallest token set covering 90% of the probability mass
    )
    # The last entry of generated_text is the assistant's reply message
    print(outputs[0]["generated_text"][-1]["content"])
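The two test cells sample with different temperatures (0.6 for the base model, 0.1 for the fine-tuned model) and the same `top_p=0.9`. The sketch below shows, on a hypothetical four-token vocabulary with made-up logits, how temperature scaling followed by nucleus (top-p) filtering reshapes the next-token distribution. It is a simplified illustration of the technique, not the exact code path inside `transformers`; the function name `temperature_top_p` is hypothetical.

```python
import math
from typing import Dict


def temperature_top_p(logits: Dict[str, float], temperature: float, top_p: float) -> Dict[str, float]:
    """Return the renormalized distribution left after temperature scaling and top-p filtering."""
    # Temperature scaling: divide logits by T, then apply a numerically stable softmax.
    scaled = {tok: value / temperature for tok, value in logits.items()}
    max_logit = max(scaled.values())
    exps = {tok: math.exp(v - max_logit) for tok, v in scaled.items()}
    total = sum(exps.values())
    probs = {tok: e / total for tok, e in exps.items()}

    # Nucleus filtering: keep the smallest set of most likely tokens whose
    # cumulative probability reaches top_p (the most likely token is always kept).
    kept: Dict[str, float] = {}
    cumulative = 0.0
    for tok, p in sorted(probs.items(), key=lambda item: item[1], reverse=True):
        kept[tok] = p
        cumulative += p
        if cumulative >= top_p:
            break

    # Renormalize the surviving tokens so they sum to 1.
    kept_total = sum(kept.values())
    return {tok: p / kept_total for tok, p in kept.items()}


# Hypothetical next-token logits for a four-token vocabulary.
logits = {"a": 2.0, "b": 1.5, "c": 1.0, "d": -1.0}

for t in (0.6, 0.1):
    dist = temperature_top_p(logits, temperature=t, top_p=0.9)
    print(t, {tok: round(p, 3) for tok, p in dist.items()})
```

In this toy example, three tokens survive the 0.9 cutoff at temperature 0.6, while at temperature 0.1 only the most likely token remains, so sampling becomes effectively greedy. Because the two cells use different temperatures, some of the difference in their outputs may come from the sampling settings rather than from fine-tuning alone.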
Happy Fine-Tuning! 😄✨
Congratulations on reaching this exciting milestone! You now possess the tools and knowledge to fine-tune the powerful LLaMA 3 model on your own custom datasets. This achievement opens up a world of possibilities for you to explore and unleash the full potential of this cutting-edge language model.
We encourage you to embrace the spirit of experimentation and exploration. Feel free to customize and adapt this notebook to fit your specific use case. Try different datasets, tweak the hyperparameters, and observe how the model’s performance evolves. This hands-on experience will deepen your understanding and allow you to tailor the model to your unique requirements.
Moreover, we invite you to share your fine-tuned models and experiences with the broader community. Consider open-sourcing your work on platforms like GitHub or Hugging Face, and write blog posts to detail your fine-tuning journey. Your insights and achievements can inspire and assist others who are embarking on their own fine-tuning projects, fostering a collaborative and supportive environment for knowledge sharing.
If you encounter any challenges or have suggestions for improvement, please don’t hesitate to reach out and provide feedback. We value your input and are committed to making this notebook and the fine-tuning process as smooth and enjoyable as possible. Your feedback will help us refine and enhance the resources available to the community.
Remember, the journey of fine-tuning language models is an iterative and continuous process. Embrace the challenges, celebrate your successes, and continue pushing the boundaries of what’s possible. Together, we can unlock the full potential of these powerful models and drive innovation in various domains.
Star my repo please
Download
You can download the Intel Notebook by clicking the link below:
Download llama3-8B-FT-Intel-XPU.0.0.3.ipynb
SPDX-License-Identifier: Apache-2.0
Copyright (c) 2024, Rahul Unnikrishnan Nair (rahul.unnikrishnan.nair@intel.com), Ray Bernard (ray.bernard@outlook.com)