This content originally appeared on DEV Community and was authored by Prashant Nigam
This is where the magic happens! In this part, we will take a deep dive into LoRA (Low-Rank Adaptation) fine-tuning and use MLX to train our model with incredible efficiency on Apple Silicon.
Understanding LoRA: The Game-Changing Technique
Imagine you are a master chef who wants to learn a new cuisine. Instead of forgetting everything you know and starting from scratch, you add new techniques and flavor profiles to your existing knowledge. That’s exactly what LoRA (Low-Rank Adaptation) does for language models.
The Traditional Fine-Tuning Problem
Traditional fine-tuning updates all 1.7 billion parameters of our model. This means:
Massive memory requirements
Slow training
Risk of “catastrophic forgetting” (losing general knowledge)
Large model files
The LoRA Solution
LoRA adds small “adapter” layers that learn new behaviors while keeping the original model frozen:
Minimal memory usage
Fast training
Preserves general knowledge
Tiny adapter file size
Can be combined or switched out easily
How LoRA Works Under the Hood
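Think of the original model as a Swiss Army knife with all its tools welded in place. LoRA adds new attachments that can be snapped on or off. Concretely, each adapted weight matrix $W$ stays frozen while LoRA trains two much smaller matrices $A$ and $B$ whose product $BA$ has rank at most $r$; in the original LoRA formulation the layer then uses $W + \frac{\alpha}{r} BA$, so only $A$ and $B$ (a tiny fraction of the parameters) are updated and saved.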
MLX: Apple’s Secret Weapon for AI
MLX is Apple’s machine learning framework designed specifically for Apple Silicon. It’s what makes our local fine-tuning possible and incredibly fast.
Why MLX is good for Local AI
- Unified Memory Architecture: M-series chips share memory between CPU and GPU, eliminating data transfer bottlenecks
- Optimized Computation: Hand-tuned for Apple Silicon’s specific capabilities
- Memory Efficiency: Intelligent memory management for maximum model sizes
- Python Integration: Easy to use while being incredibly fast
Setting Up Our Fine-Tuning Pipeline
Let us build our fine-tuning system step by step, understanding each component.
Step 1: Configuration and Setup
First, let’s create a comprehensive configuration system:
touch fine_tuning_config.py
# Create fine_tuning_config.py
import os
from pathlib import Path
import mlx.core as mx
class FineTuningConfig:
    """Single place that holds every setting used by the fine-tuning run."""

    def __init__(self):
        # -- Model ---------------------------------------------------------
        self.base_model = "HuggingFaceTB/SmolLM2-1.7B-Instruct"
        self.adapter_path = "./adapters/email_sentiment"

        # -- Dataset locations ---------------------------------------------
        self.train_data_path = "./data/mlx_format/train.jsonl"
        self.valid_data_path = "./data/mlx_format/valid.jsonl"

        # -- LoRA ----------------------------------------------------------
        self.lora_layers = 16  # how many transformer layers receive adapters
        self.lora_rank = 16  # the 'r' in LoRA: more capacity, but slower
        self.lora_alpha = 32  # scaling factor applied to the adapters

        # -- Training loop -------------------------------------------------
        self.batch_size = 2  # lower this when memory runs out
        self.learning_rate = 5e-5
        self.max_iters = 1000  # upper bound on training iterations
        self.steps_per_report = 10  # progress print interval
        self.steps_per_eval = 200  # validation interval
        self.save_every = 400  # checkpoint interval

        # -- Hardware ------------------------------------------------------
        self.use_gpu = mx.metal.is_available()
        self.max_sequence_length = 2048

        # Constructing a config also makes sure the adapter directory exists.
        os.makedirs(self.adapter_path, exist_ok=True)

    def print_config(self):
        """Write the most important settings to stdout, one per line."""
        rows = (
            ("Base model", self.base_model),
            ("GPU available", self.use_gpu),
            ("LoRA rank", self.lora_rank),
            ("LoRA layers", self.lora_layers),
            ("Batch size", self.batch_size),
            ("Learning rate", self.learning_rate),
            ("Max iterations", self.max_iters),
            ("Adapter path", self.adapter_path),
        )
        print("🔧 Fine-tuning Configuration:")
        for label, value in rows:
            print(f" {label}: {value}")
# Create and test config: running this file directly builds the default
# configuration (whose constructor also creates the adapter directory)
# and prints it.
if __name__ == "__main__":
    config = FineTuningConfig()
    config.print_config()
Step 2: Memory and Performance Monitoring
Before we start fine-tuning, let’s create tools to monitor our system:
touch monitoring.py
# Create monitoring.py
import time
import mlx.core as mx
from typing import Dict, List
import psutil
class PerformanceMonitor:
    """Record memory usage and timing while a training run progresses.

    Each call to :meth:`log_memory_usage` appends one metrics dict to
    ``self.metrics``; :meth:`get_training_summary` condenses that history.
    """

    def __init__(self, print_every: int = 50):
        """Start the wall-clock timer.

        Args:
            print_every: Print a status line whenever the logged step is a
                multiple of this value. Zero or a negative value disables
                the automatic printing.
        """
        self.start_time = time.time()
        self.metrics = []
        self.print_every = print_every

    def log_memory_usage(self, step: int, loss: float = None):
        """Capture memory and timing for ``step`` and return the metrics dict.

        Args:
            step: Training step the sample belongs to.
            loss: Training loss at that step, or ``None`` when unknown.

        Returns:
            The dict that was appended to ``self.metrics``.
        """
        # GPU memory (only queried when Metal is available).
        # NOTE(review): newer MLX releases may expose these counters directly
        # on ``mx`` rather than ``mx.metal`` - confirm against the installed
        # version.
        gpu_memory = {}
        if mx.metal.is_available():
            gpu_memory = {
                'active_mb': mx.metal.get_active_memory() / 1e6,
                'peak_mb': mx.metal.get_peak_memory() / 1e6,
            }

        system_memory = psutil.virtual_memory()
        elapsed = time.time() - self.start_time

        metrics = {
            'step': step,
            'elapsed_seconds': elapsed,
            'loss': loss,
            'gpu_active_mb': gpu_memory.get('active_mb', 0),
            'gpu_peak_mb': gpu_memory.get('peak_mb', 0),
            'system_memory_percent': system_memory.percent,
            'system_memory_available_gb': system_memory.available / 1e9,
        }
        self.metrics.append(metrics)

        if self.print_every > 0 and step % self.print_every == 0:
            self.print_status(metrics)
        return metrics

    def print_status(self, metrics: Dict):
        """Print a one-line status for a metrics dict.

        A missing or ``None`` loss is rendered as ``n/a``; formatting it
        with ``:.4f`` would raise ``TypeError``.
        """
        loss = metrics.get('loss')
        loss_text = "n/a" if loss is None else f"{loss:.4f}"
        print(f"Step {metrics['step']:4d} | "
              f"Loss: {loss_text} | "
              f"GPU: {metrics['gpu_active_mb']:.0f}MB | "
              f"Time: {metrics['elapsed_seconds']:.1f}s")

    def get_training_summary(self):
        """Summarise the recorded metrics.

        Returns:
            An empty dict when nothing has been logged; otherwise a dict with
            the elapsed time and loss of the last sample, the highest
            recorded GPU peak, and the number of samples.
        """
        if not self.metrics:
            return {}
        last = self.metrics[-1]
        return {
            'total_training_time': last['elapsed_seconds'],
            'peak_gpu_memory_mb': max(m['gpu_peak_mb'] for m in self.metrics),
            'final_loss': last['loss'],
            'steps_completed': len(self.metrics),
        }
Step 3: The Fine-Tuning Engine
Now let’s create our main fine-tuning script using MLX-LM:
touch fine_tune_model.py
# Create fine_tune_model.py
import subprocess
import time
import json
import os
from pathlib import Path
from fine_tuning_config import FineTuningConfig
from monitoring import PerformanceMonitor
class MLXFineTuner:
    """Run an ``mlx_lm`` LoRA training job as a subprocess and report on it."""

    def __init__(self, config: "FineTuningConfig"):
        """Store the configuration and create a performance monitor."""
        self.config = config
        self.monitor = PerformanceMonitor()

    def validate_data(self):
        """Check that the training file exists and looks like JSONL.

        Returns:
            The number of non-blank lines in the training file.

        Raises:
            FileNotFoundError: The training file does not exist.
            ValueError: The file has no records, the first record is not
                valid JSON, or it has no ``'text'`` field.
        """
        print("📊 Validating training data...")
        train_path = self.config.train_data_path
        if not os.path.exists(train_path):
            raise FileNotFoundError(f"Training data not found: {train_path}")

        # Single pass: count records and remember the first non-blank one.
        train_count = 0
        first_record = None
        with open(train_path, 'r', encoding='utf-8') as f:
            for line in f:
                if not line.strip():
                    continue
                if first_record is None:
                    first_record = line
                train_count += 1
        print(f"✅ Found {train_count} training examples")

        if first_record is None:
            raise ValueError(f"Training data is empty: {train_path}")
        try:
            example = json.loads(first_record)
        except json.JSONDecodeError as err:
            raise ValueError("Training data must be valid JSONL format") from err
        # A record that is valid JSON but not an object cannot carry 'text'.
        if not isinstance(example, dict) or 'text' not in example:
            raise ValueError("Training data must have 'text' field")
        print("✅ Data format validated")
        return train_count

    def build_training_command(self):
        """Build the ``mlx_lm lora`` command line as a list of strings."""
        import sys  # local import: only this method needs the interpreter path

        cfg = self.config
        # Use the directory that actually holds the configured training file
        # instead of a hard-coded path, so changing the config takes effect.
        data_dir = os.path.dirname(cfg.train_data_path) or "."
        # Run with the current interpreter so the subprocess sees the same
        # environment (and therefore the same installed mlx_lm).
        python = sys.executable or "python3"
        # NOTE(review): lora_layers, lora_rank, lora_alpha and
        # max_sequence_length from the config are not forwarded here, so the
        # tool's own defaults apply - confirm which flags or config-file keys
        # the installed mlx_lm version accepts before adding them.
        return [
            python, "-m", "mlx_lm", "lora",
            "--model", cfg.base_model,
            "--train",
            "--data", data_dir,  # directory containing train.jsonl
            "--batch-size", str(cfg.batch_size),
            "--iters", str(cfg.max_iters),
            "--learning-rate", str(cfg.learning_rate),
            "--steps-per-report", str(cfg.steps_per_report),
            "--steps-per-eval", str(cfg.steps_per_eval),
            "--adapter-path", cfg.adapter_path,
            "--save-every", str(cfg.save_every),
        ]

    def _run_and_stream(self, cmd):
        """Run ``cmd``, echoing its output live; return ``(returncode, output)``."""
        # Ask a Python child not to buffer stdout so lines arrive promptly.
        env = dict(os.environ, PYTHONUNBUFFERED="1")
        captured = []
        with subprocess.Popen(
            cmd,
            stdout=subprocess.PIPE,
            stderr=subprocess.STDOUT,  # one merged stream keeps line order
            text=True,
            bufsize=1,
            env=env,
        ) as proc:
            for line in proc.stdout:
                print(line, end="", flush=True)
                captured.append(line)
        return proc.returncode, "".join(captured)

    def run_fine_tuning(self):
        """Validate the data, run training and write a metadata file.

        Returns:
            ``(True, metadata)`` on success, ``(False, None)`` when the
            training process cannot be started or exits with a non-zero code.

        Raises:
            FileNotFoundError, ValueError: Propagated from
                :meth:`validate_data`.
        """
        print("🚀 Starting LoRA fine-tuning with MLX...")
        print("=" * 60)

        train_count = self.validate_data()
        self.config.print_config()

        cmd = self.build_training_command()
        print(f"\n📝 Command: {' '.join(cmd)}")

        start_time = time.time()
        print(f"\n🏃 Training started at {time.strftime('%H:%M:%S')}")
        print(f"📚 Training on {train_count} examples")
        print("💡 This typically takes 3-10 minutes on Apple Silicon M3")
        print(f"⏰ Progress will be reported every {self.config.steps_per_report} steps\n")

        try:
            # Stream instead of capturing silently, so the progress promised
            # above is visible while training runs.
            returncode, output = self._run_and_stream(cmd)
        except OSError as err:
            print("\n❌ Fine-tuning failed!")
            print(f"Could not start the training process: {err}")
            return False, None

        if returncode != 0:
            print("\n❌ Fine-tuning failed!")
            print(f"Error code: {returncode}")
            print("Error output: see the training log above")
            return False, None

        training_time = time.time() - start_time
        print("\n" + "=" * 60)
        print("🎉 Fine-tuning completed successfully!")
        print(f"⏱ Total training time: {training_time:.1f} seconds")
        print(f"💾 Adapters saved to: {self.config.adapter_path}")

        # The LoRA values below are the configured ones; see the note in
        # build_training_command about what is actually forwarded.
        metadata = {
            'model_name': self.config.base_model,
            'training_time_seconds': training_time,
            'training_examples': train_count,
            'lora_rank': self.config.lora_rank,
            'lora_layers': self.config.lora_layers,
            'batch_size': self.config.batch_size,
            'learning_rate': self.config.learning_rate,
            'max_iters': self.config.max_iters,
            'timestamp': time.time(),
            'command_used': ' '.join(cmd),
        }
        os.makedirs(self.config.adapter_path, exist_ok=True)
        metadata_path = os.path.join(self.config.adapter_path, "training_metadata.json")
        with open(metadata_path, 'w', encoding='utf-8') as f:
            json.dump(metadata, f, indent=2)
        print(f"📊 Training metadata saved to: {metadata_path}")

        self.parse_training_output(output)
        return True, metadata

    def parse_training_output(self, output: str):
        """Print the loss-related lines of ``output`` and the last train loss.

        Matching is case-insensitive and accepts both ``Loss: 1.23`` and
        ``Train loss 1.23`` style lines.
        """
        import re  # local import: only needed for loss extraction

        print("\n📈 Training Progress Summary:")
        print("-" * 40)
        lines = output.split('\n')

        for line in lines:
            if 'loss' in line.lower() or 'Validation' in line:
                print(f" {line.strip()}")

        number = r"([-+]?(?:\d+(?:\.\d*)?|\.\d+)(?:[eE][-+]?\d+)?|nan|inf)"
        train_loss = re.compile(
            r"(?:\btrain(?:ing)?\s+loss\b:?|\bloss:)\s*" + number,
            re.IGNORECASE,
        )
        for line in reversed(lines):
            lowered = line.lower()
            # Validation lines must not be reported as the training loss.
            if 'val loss' in lowered or 'validation' in lowered:
                continue
            match = train_loss.search(line)
            if match:
                print(f"\n🎯 Final training loss: {match.group(1)}")
                break

    def verify_training_output(self):
        """Check that the adapter directory contains weight files.

        Returns:
            ``True`` when at least one ``*.safetensors`` or ``*.npz`` file is
            present in the adapter directory, ``False`` otherwise.
        """
        print("\n🔍 Verifying training output...")
        adapter_path = Path(self.config.adapter_path)

        adapter_files = sorted(
            found
            for pattern in ("*.safetensors", "*.npz")
            for found in adapter_path.glob(pattern)
        )
        if not adapter_files:
            print("❌ No adapter files found")
            return False
        print(f"✅ Found adapter files: {[f.name for f in adapter_files]}")

        config_file = adapter_path / "adapter_config.json"
        if config_file.exists():
            print(f"✅ Found adapter config: {config_file}")
            with open(config_file, 'r', encoding='utf-8') as f:
                config_data = json.load(f)
            # NOTE(review): mlx_lm appears to nest LoRA settings under
            # 'lora_parameters' (rank / scale) rather than the PEFT-style
            # top-level 'r' / 'lora_alpha' - both layouts are checked here;
            # confirm against the installed version.
            lora_params = config_data.get('lora_parameters')
            if not isinstance(lora_params, dict):
                lora_params = {}
            rank = config_data.get('r', lora_params.get('rank', 'unknown'))
            alpha = config_data.get('lora_alpha', lora_params.get('alpha', 'unknown'))
            print(f" LoRA rank: {rank}")
            print(f" LoRA alpha: {alpha}")
            if 'scale' in lora_params:
                print(f" LoRA scale: {lora_params['scale']}")
        else:
            print("⚠ No adapter config found")

        total_size = sum(f.stat().st_size for f in adapter_path.rglob('*') if f.is_file())
        print(f"📁 Total adapter size: {total_size / 1e6:.1f} MB")
        return True
def main():
    """Run the whole pipeline: configure, fine-tune, then verify the output.

    Returns:
        The training metadata dict when training succeeds and adapter files
        are found afterwards; ``None`` otherwise.
    """
    print("🤖 MLX LoRA Fine-Tuning Pipeline")
    print("=" * 50)

    config = FineTuningConfig()
    fine_tuner = MLXFineTuner(config)

    success, metadata = fine_tuner.run_fine_tuning()
    if not success:
        print("\n💥 Fine-tuning failed. Please check the error messages above.")
        return None

    # A zero exit code alone does not prove adapters were written, so the
    # verification result decides whether the run counts as successful.
    if not fine_tuner.verify_training_output():
        print("\n💥 Training finished but no adapter files were found. "
              "Please check the training log above.")
        return None

    print("\n✨ Fine-tuning pipeline completed successfully!")
    print("\n🎯 Next steps:")
    print(" 1. Test your fine-tuned model")
    print(" 2. Run evaluation to measure performance")
    print(" 3. Build your application interface")
    return metadata
# Script entry point: run the pipeline and keep the returned metadata
# (None when main() reports a failure) in a module-level name.
if __name__ == "__main__":
    metadata = main()
This content originally appeared on DEV Community and was authored by Prashant Nigam