The Magic of LoRA Fine-Tuning with MLX (Part 4)



This content originally appeared on DEV Community and was authored by Prashant Nigam

This is where the magic happens! In this part, we will take a deep dive into LoRA (Low-Rank Adaptation) fine-tuning and use MLX to train our model efficiently on Apple Silicon.

Understanding LoRA: The Game-Changing Technique

Imagine you are a master chef who wants to learn a new cuisine. Instead of forgetting everything you know and starting from scratch, you add new techniques and flavor profiles to your existing knowledge. That’s exactly what LoRA (Low-Rank Adaptation) does for language models.

The Traditional Fine-Tuning Problem

Traditional fine-tuning updates all 1.7 billion parameters of our model. This means:

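  • ❌ Massive memory requirements (weights, gradients, and optimizer state for every parameter)
  • ❌ Slow training
  • ❌ Risk of “catastrophic forgetting” (losing general knowledge)
  • ❌ Large model files (every fine-tuned version is a full copy of the model)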
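To put “massive” into numbers, here is a rough back-of-the-envelope sketch. It assumes every value is stored as a 32-bit float (4 bytes) and an Adam-style optimizer that keeps two extra buffers per parameter, and it ignores activations entirely, so real usage would be higher; mixed-precision training would change the totals.

```python
def full_finetune_memory_gb(num_params: int, bytes_per_value: int = 4) -> dict:
    """Rough memory estimate for full fine-tuning, ignoring activations.

    Assumes every value is stored in 32-bit floats (4 bytes) and an
    Adam-style optimizer that keeps two moment buffers per parameter.
    """
    gb = 1e9
    weights = num_params * bytes_per_value
    gradients = num_params * bytes_per_value        # one gradient per weight
    optimizer = 2 * num_params * bytes_per_value    # Adam's first and second moments
    return {
        "weights_gb": weights / gb,
        "gradients_gb": gradients / gb,
        "optimizer_gb": optimizer / gb,
        "total_gb": (weights + gradients + optimizer) / gb,
    }


estimate = full_finetune_memory_gb(1_700_000_000)
for name, value in estimate.items():
    print(f"{name:>13}: {value:5.1f} GB")
```

Under these assumptions that is about 27 GB before activations, four times the roughly 6.8 GB needed just to hold the weights, which is why updating every parameter is out of reach on many laptops.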

The LoRA Solution

LoRA keeps the original model frozen and trains small “adapter” matrices attached to selected weight layers, which learn the new behavior:

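  • ✅ Much lower memory usage (only the adapters need gradients and optimizer state)
  • ✅ Fast training
  • ✅ Preserves general knowledge
  • ✅ Small adapter files (megabytes rather than gigabytes)
  • ✅ Can be combined or swapped in and out on the same base model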
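A quick calculation shows why adapter files stay small. For a weight matrix of shape d_out × d_in, LoRA adds two trainable matrices holding r × (d_out + d_in) parameters in total. The shapes below are hypothetical: a hidden size of 2048, LoRA on two square 2048 × 2048 attention projections per layer (check which modules your mlx_lm version actually adapts), 16 adapted layers, rank 16, and 32-bit storage.

```python
def lora_param_count(d_out: int, d_in: int, rank: int) -> int:
    """Trainable parameters LoRA adds to one d_out x d_in weight: B is d_out x r, A is r x d_in."""
    return rank * (d_out + d_in)


# Hypothetical setup: hidden size 2048, two 2048 x 2048 projections per layer
hidden = 2048
rank = 16
adapted_layers = 16
modules_per_layer = 2

per_module = lora_param_count(hidden, hidden, rank)
total_lora = per_module * modules_per_layer * adapted_layers
frozen_params = hidden * hidden * modules_per_layer * adapted_layers

print(f"LoRA parameters per module:  {per_module:,}")
print(f"Total LoRA parameters:       {total_lora:,}")
print(f"Fraction of adapted weights: {total_lora / frozen_params:.2%}")
print(f"Adapter size at 4 bytes/param: {total_lora * 4 / 1e6:.1f} MB")
```

Under these assumptions the adapters hold about 2.1 million parameters, roughly 8 MB, compared with about 6.8 GB for a full 32-bit copy of a 1.7B-parameter model.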

How LoRA Works Under the Hood

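Think of the original model as a Swiss Army knife with all its tools welded in place. LoRA adds new attachments that can be snapped on or off: the original weights never change, and each attachment is a small low-rank correction added to one of the model’s weight matrices. Only those small matrices are trained.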
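Concretely, for a frozen weight matrix W of shape d_out × d_in, LoRA trains two small matrices, A (r × d_in) and B (d_out × r), and the adapted layer computes y = W x + (α / r) · B A x. The NumPy sketch below illustrates that idea; it is not MLX’s implementation, and the class name, initialization scale, and shapes are illustrative assumptions. B starts at zero, so before training the adapted layer behaves exactly like the original one.

```python
import numpy as np


class LoRALinearSketch:
    """Illustrative LoRA wrapper around a frozen dense weight (not mlx_lm's class)."""

    def __init__(self, weight: np.ndarray, rank: int = 16, alpha: float = 32.0, seed: int = 0):
        rng = np.random.default_rng(seed)
        d_out, d_in = weight.shape
        self.weight = weight                          # frozen: never updated
        # A starts with small random values and B with zeros, so B @ A == 0 at first
        self.A = rng.normal(0.0, 0.01, size=(rank, d_in))
        self.B = np.zeros((d_out, rank))
        self.scale = alpha / rank                     # the alpha / r scaling factor

    def __call__(self, x: np.ndarray) -> np.ndarray:
        # x: (batch, d_in) -> output: (batch, d_out)
        base = x @ self.weight.T
        low_rank = (x @ self.A.T) @ self.B.T          # d_in -> r -> d_out
        return base + self.scale * low_rank


rng = np.random.default_rng(42)
W = rng.normal(size=(64, 128))
x = rng.normal(size=(2, 128))
layer = LoRALinearSketch(W, rank=4, alpha=8.0)

# Before training, B is zero, so the adapted layer matches the frozen one
print(np.allclose(layer(x), x @ W.T))                 # True

# Pretend training has updated B: the output now differs from the base layer
layer.B = rng.normal(size=(64, 4))
print(np.allclose(layer(x), x @ W.T))                 # False

# The same result also comes from one merged matrix W + scale * B @ A
merged = W + layer.scale * (layer.B @ layer.A)
print(np.allclose(layer(x), x @ merged.T))            # True
```

The last check shows the design trade-off: an adapter can be merged into the base weights for deployment, or kept separate so several adapters can share one frozen base model.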

MLX: Apple’s Secret Weapon for AI

MLX is Apple’s open-source machine learning framework designed specifically for Apple Silicon. It is what makes fine-tuning a 1.7B-parameter model practical on a local machine.

Why MLX Is Good for Local AI

  1. Unified Memory Architecture: M-series chips share one memory pool between CPU and GPU, so data does not have to be copied between them
  2. Optimized Computation: GPU kernels built on Metal and tuned for Apple Silicon
  3. Memory Efficiency: Lazy evaluation means arrays are only computed when their values are needed, which helps larger models fit in memory
  4. Python Integration: A NumPy-like Python API that is easy to pick up

Setting Up Our Fine-Tuning Pipeline

Let’s build our fine-tuning system step by step, looking at each component as we go.

Step 1: Configuration and Setup

First, let’s create a central configuration class:

touch fine_tuning_config.py

# Create fine_tuning_config.py
import os
from pathlib import Path
import mlx.core as mx

class FineTuningConfig:
    """Centralized configuration for fine-tuning"""

    def __init__(self):
        # Model configuration
        self.base_model = "HuggingFaceTB/SmolLM2-1.7B-Instruct"
        self.adapter_path = "./adapters/email_sentiment"

        # Data paths
        self.train_data_path = "./data/mlx_format/train.jsonl"
        self.valid_data_path = "./data/mlx_format/valid.jsonl"

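        # LoRA parameters
        self.lora_layers = 16  # Number of transformer layers to add LoRA to (mlx_lm adapts the last N)
        self.lora_rank = 16    # The 'r' in LoRA: higher = more capacity, but slower and larger adapters
        self.lora_alpha = 32   # Scaling factor: the LoRA update is scaled by alpha / rank
        # Note: mlx_lm reads rank and scaling from a YAML config file (-c), not from CLI flags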

        # Training parameters
        self.batch_size = 2           # Batch size (reduce if out of memory)
        self.learning_rate = 5e-5     # Learning rate
        self.max_iters = 1000         # Maximum training iterations
        self.steps_per_report = 10    # How often to print progress
        self.steps_per_eval = 200     # How often to run validation
        self.save_every = 400         # How often to save checkpoints

        # Hardware optimization
        self.use_gpu = mx.metal.is_available()
        self.max_sequence_length = 2048

        # Create directories
        Path(self.adapter_path).mkdir(parents=True, exist_ok=True)

    def print_config(self):
        """Print current configuration"""
        print("🔧 Fine-tuning Configuration:")
        print(f"  Base model: {self.base_model}")
        print(f"  GPU available: {self.use_gpu}")
        print(f"  LoRA rank: {self.lora_rank}")
        print(f"  LoRA layers: {self.lora_layers}")
        print(f"  Batch size: {self.batch_size}")
        print(f"  Learning rate: {self.learning_rate}")
        print(f"  Max iterations: {self.max_iters}")
        print(f"  Adapter path: {self.adapter_path}")

# Create and test config
if __name__ == "__main__":
    config = FineTuningConfig()
    config.print_config()
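The rank and alpha values above are recorded but not passed to the trainer by the command built in Step 3. In recent mlx_lm versions, LoRA hyperparameters such as the rank are set through a YAML file given with -c/--config, under a lora_parameters section with rank, scale, and dropout keys (as in the example lora_config.yaml in the mlx-lm repository); treat those key names as an assumption to check against your installed version. The helper lora_yaml_from_config is hypothetical: it renders the YAML by hand, with no PyYAML dependency, and maps alpha to scale = alpha / rank.

```python
from types import SimpleNamespace


def lora_yaml_from_config(config) -> str:
    """Render a minimal mlx_lm-style YAML config (hypothetical helper).

    Assumes mlx_lm reads a `lora_parameters` section with `rank`,
    `scale`, and `dropout` keys, where `scale` plays the role of alpha / r.
    """
    scale = config.lora_alpha / config.lora_rank
    lines = [
        "lora_parameters:",
        f"  rank: {config.lora_rank}",
        f"  scale: {scale}",
        "  dropout: 0.0",
    ]
    return "\n".join(lines) + "\n"


# Stand-in for FineTuningConfig so the snippet runs on its own
demo_config = SimpleNamespace(lora_rank=16, lora_alpha=32)
print(lora_yaml_from_config(demo_config))
```

With the real FineTuningConfig you could write this string to a file next to the adapters and add "-c" plus that file’s path to the training command list.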

Step 2: Memory and Performance Monitoring

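Before we start fine-tuning, let’s create a small helper for monitoring memory and timing. Keep in mind that MLX’s GPU memory counters only describe the process that calls them; since Step 3 launches training as a separate mlx_lm process, this monitor is mainly useful for code that runs MLX in the same process.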

touch monitoring.py

# Create monitoring.py
import time
import mlx.core as mx
from typing import Dict, Optional
import psutil  # third-party: pip install psutil

class PerformanceMonitor:
    """Monitor memory usage and training performance"""

    def __init__(self):
        self.start_time = time.time()
        self.metrics = []

    def log_memory_usage(self, step: int, loss: Optional[float] = None):
        """Log current memory and performance metrics"""

        # GPU memory (if available)
        gpu_memory = {}
        if mx.metal.is_available():
            gpu_memory = {
                'active_mb': mx.metal.get_active_memory() / 1e6,
                'peak_mb': mx.metal.get_peak_memory() / 1e6
            }

        # System memory
        system_memory = psutil.virtual_memory()

        # Training metrics
        elapsed = time.time() - self.start_time

        metrics = {
            'step': step,
            'elapsed_seconds': elapsed,
            'loss': loss,
            'gpu_active_mb': gpu_memory.get('active_mb', 0),
            'gpu_peak_mb': gpu_memory.get('peak_mb', 0),
            'system_memory_percent': system_memory.percent,
            'system_memory_available_gb': system_memory.available / 1e9
        }

        self.metrics.append(metrics)

        if step % 50 == 0:  # Print every 50 steps
            self.print_status(metrics)

        return metrics

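    def print_status(self, metrics: Dict):
        """Print current training status"""

        # loss is optional, so avoid formatting None with :.4f
        loss = metrics['loss']
        loss_str = f"{loss:.4f}" if loss is not None else "n/a"
        print(f"Step {metrics['step']:4d} | "
              f"Loss: {loss_str} | "
              f"GPU: {metrics['gpu_active_mb']:.0f}MB | "
              f"Time: {metrics['elapsed_seconds']:.1f}s")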

    def get_training_summary(self):
        """Get summary of training run"""

        if not self.metrics:
            return {}

        peak_gpu = max(m['gpu_peak_mb'] for m in self.metrics)
        total_time = self.metrics[-1]['elapsed_seconds']
        final_loss = self.metrics[-1]['loss']

        return {
            'total_training_time': total_time,
            'peak_gpu_memory_mb': peak_gpu,
            'final_loss': final_loss,
            'steps_completed': self.metrics[-1]['step']  # last logged step, not the number of log entries
        }

Step 3: The Fine-Tuning Engine

Now let’s create our main fine-tuning script, which runs the MLX-LM lora trainer as a subprocess:

touch fine_tune_model.py

# Create fine_tune_model.py
import subprocess
import time
import json
import os
from pathlib import Path
from fine_tuning_config import FineTuningConfig
from monitoring import PerformanceMonitor

class MLXFineTuner:
    """Fine-tune models using MLX with LoRA"""

    def __init__(self, config: FineTuningConfig):
        self.config = config
        self.monitor = PerformanceMonitor()

    def validate_data(self):
        """Validate that training data exists and is properly formatted"""

        print("📊 Validating training data...")

        # mlx_lm expects both train.jsonl and valid.jsonl in the --data directory
        for path in (self.config.train_data_path, self.config.valid_data_path):
            if not os.path.exists(path):
                raise FileNotFoundError(f"Data file not found: {path}")

        # Count training examples
        train_count = 0
        with open(self.config.train_data_path, 'r') as f:
            for line in f:
                if line.strip():
                    train_count += 1

        print(f"✅ Found {train_count} training examples")

        # Validate format
        with open(self.config.train_data_path, 'r') as f:
            first_line = f.readline()
            try:
                example = json.loads(first_line)
                if 'text' not in example:
                    raise ValueError("Training data must have 'text' field")
                print("✅ Data format validated")
            except json.JSONDecodeError:
                raise ValueError("Training data must be valid JSONL format")

        return train_count

    def build_training_command(self):
        """Build the MLX-LM training command"""

        cmd = [
            "python3", "-m", "mlx_lm", "lora",
            "--model", self.config.base_model,
            "--train",
            # Directory containing train.jsonl and valid.jsonl
            "--data", str(Path(self.config.train_data_path).parent),
            # Older mlx_lm releases call this flag --lora-layers
            "--num-layers", str(self.config.lora_layers),
            "--batch-size", str(self.config.batch_size),
            "--iters", str(self.config.max_iters),
            "--learning-rate", str(self.config.learning_rate),
            "--steps-per-report", str(self.config.steps_per_report),
            "--steps-per-eval", str(self.config.steps_per_eval),
            "--max-seq-length", str(self.config.max_sequence_length),
            "--adapter-path", self.config.adapter_path,
            "--save-every", str(self.config.save_every)
        ]

        return cmd

    def run_fine_tuning(self):
        """Execute the fine-tuning process"""

        print("🚀 Starting LoRA fine-tuning with MLX...")
        print("=" * 60)

        # Validate everything is ready
        train_count = self.validate_data()
        self.config.print_config()

        # Build command
        cmd = self.build_training_command()
        print(f"\n📝 Command: {' '.join(cmd)}")

        # Start training
        start_time = time.time()

        print(f"\n🏃 Training started at {time.strftime('%H:%M:%S')}")
        print(f"📚 Training on {train_count} examples")
        print("💡 This typically takes 3-10 minutes on Apple Silicon M3")
        print(f"⏰ Progress will be reported every {self.config.steps_per_report} steps\n")

        try:
            # Stream the trainer's output live (stderr merged into stdout)
            # and keep a copy so it can be summarized afterwards
            output_lines = []
            with subprocess.Popen(cmd, stdout=subprocess.PIPE,
                                  stderr=subprocess.STDOUT, text=True) as process:
                for line in process.stdout:
                    print(line, end="")
                    output_lines.append(line)
            # Leaving the with-block waits for the process, so returncode is set
            if process.returncode != 0:
                raise subprocess.CalledProcessError(
                    process.returncode, cmd, output="".join(output_lines))

            training_time = time.time() - start_time

            print("\n" + "=" * 60)
            print("🎉 Fine-tuning completed successfully!")
            print(f"⏱  Total training time: {training_time:.1f} seconds")
            print(f"💾 Adapters saved to: {self.config.adapter_path}")

            # Save training metadata
            metadata = {
                'model_name': self.config.base_model,
                'training_time_seconds': training_time,
                'training_examples': train_count,
                'lora_rank': self.config.lora_rank,
                'lora_layers': self.config.lora_layers,
                'batch_size': self.config.batch_size,
                'learning_rate': self.config.learning_rate,
                'max_iters': self.config.max_iters,
                'timestamp': time.time(),
                'command_used': ' '.join(cmd)
            }

            metadata_path = f"{self.config.adapter_path}/training_metadata.json"
            with open(metadata_path, 'w') as f:
                json.dump(metadata, f, indent=2)

            print(f"📊 Training metadata saved to: {metadata_path}")

            # Parse and display training output
            self.parse_training_output("".join(output_lines))

            return True, metadata

        except subprocess.CalledProcessError as e:
            print("\n❌ Fine-tuning failed!")
            print(f"Error code: {e.returncode}")
            print("The trainer's full output, including the error, is shown above.")
            return False, None

    def parse_training_output(self, output: str):
        """Parse and display key information from training output"""

        print("\n📈 Training Progress Summary:")
        print("-" * 40)

        lines = output.split('\n')

        # mlx_lm reports lines such as "Iter 10: Train loss 2.345, ..." and
        # "Iter 200: Val loss 2.100, ..." (exact wording may vary by version)
        for line in lines:
            if 'Train loss' in line or 'Val loss' in line:
                print(f"  {line.strip()}")

        # The last "Train loss" line holds the final training loss
        for line in reversed(lines):
            if 'Train loss' in line:
                final_loss = line.split('Train loss')[-1].split(',')[0].strip()
                print(f"\n🎯 Final training loss: {final_loss}")
                break

    def verify_training_output(self):
        """Verify that training produced the expected files"""

        print("\n🔍 Verifying training output...")

        adapter_path = Path(self.config.adapter_path)

        # Check for adapter files
        adapter_files = list(adapter_path.glob("*.safetensors")) + list(adapter_path.glob("*.npz"))
        if adapter_files:
            print(f"✅ Found adapter files: {[f.name for f in adapter_files]}")
        else:
            print("❌ No adapter files found")
            return False

        # Check for configuration
        config_file = adapter_path / "adapter_config.json"
        if config_file.exists():
            print(f"✅ Found adapter config: {config_file}")

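            # Display config contents; mlx_lm stores LoRA settings under "lora_parameters"
            with open(config_file, 'r') as f:
                config_data = json.load(f)
            lora_params = config_data.get('lora_parameters', {})
            print(f"   LoRA rank: {lora_params.get('rank', 'unknown')}")
            print(f"   LoRA scale: {lora_params.get('scale', 'unknown')}")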
        else:
            print("⚠  No adapter config found")

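        # Calculate total size of the adapter directory (this includes intermediate
        # checkpoints and the metadata file, not just the final adapters)
        total_size = sum(f.stat().st_size for f in adapter_path.rglob('*') if f.is_file())
        print(f"📁 Adapter directory size: {total_size / 1e6:.1f} MB")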

        return True

def main():
    """Main fine-tuning execution"""

    print("🤖 MLX LoRA Fine-Tuning Pipeline")
    print("=" * 50)

    # Create configuration
    config = FineTuningConfig()

    # Create fine-tuner
    fine_tuner = MLXFineTuner(config)

    # Run fine-tuning
    success, metadata = fine_tuner.run_fine_tuning()

    if success:
        # Verify output
        fine_tuner.verify_training_output()

        print("\n✨ Fine-tuning pipeline completed successfully!")
        print("\n🎯 Next steps:")
        print("  1. Test your fine-tuned model")
        print("  2. Run evaluation to measure performance")
        print("  3. Build your application interface")

        return metadata
    else:
        print("\n💥 Fine-tuning failed. Please check the error messages above.")
        return None

if __name__ == "__main__":
    metadata = main()
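If you want the loss curve as numbers rather than printed lines, a small regex pass over the captured output is enough. This sketch assumes the log format recent mlx_lm versions print, such as "Iter 10: Train loss 2.345, ..." and "Iter 200: Val loss 1.987, ..."; the format has changed between releases, so adjust the pattern to what your run actually prints. The function name parse_losses is our own, and the sample lines are made up purely to show the shape of the result.

```python
import re
from typing import Dict, List, Tuple

# Assumed mlx_lm log format, e.g. "Iter 10: Train loss 2.345, Learning Rate ..."
LOSS_PATTERN = re.compile(r"Iter (\d+): (Train|Val) loss ([0-9]*\.?[0-9]+)")


def parse_losses(output: str) -> Dict[str, List[Tuple[int, float]]]:
    """Collect (iteration, loss) pairs for training and validation lines."""
    losses: Dict[str, List[Tuple[int, float]]] = {"train": [], "val": []}
    for match in LOSS_PATTERN.finditer(output):
        iteration, kind, value = match.groups()
        losses[kind.lower()].append((int(iteration), float(value)))
    return losses


# Made-up sample lines, only to illustrate the result structure
sample = """Iter 1: Val loss 2.900, Val took 1.2s
Iter 10: Train loss 2.500, Learning Rate 5.000e-05, It/sec 1.1
Iter 20: Train loss 2.100, Learning Rate 5.000e-05, It/sec 1.1"""

print(parse_losses(sample))
```

The resulting lists can be fed straight into a plotting library or saved alongside training_metadata.json.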

