AI · Machine Learning · Optimization · On-Device

Model Distillation: Running Smaller AI Models in Browser Extensions

Extendable Team · 14 min read

Large language models provide impressive capabilities, but their size makes them impractical for browser extensions. Model distillation lets you create smaller, specialized models that run faster and can even operate entirely on-device. This guide covers practical approaches to bringing distilled AI models to your extensions.

Why Distillation for Extensions?

Browser extensions face unique constraints:

  • Bundle size: Extensions over 10MB are slow to install
  • Memory: Browsers limit extension memory usage
  • Privacy: Users prefer on-device processing
  • Latency: API round-trips add 200ms+ delay
  • Costs: API calls add up quickly
Distillation Trade-offs:
  • Smaller models = faster inference but lower capability
  • Task-specific models excel at one thing
  • On-device = private but limited by hardware
  • Hybrid approaches often work best

Understanding Model Distillation

Distillation transfers knowledge from a large “teacher” model to a smaller “student” model:

Teacher Model (e.g., GPT-4)

         │ Generate training data
         │ Soft labels (probabilities)
         ▼
    ┌─────────────┐
    │  Training   │
    │   Process   │
    └─────────────┘
         │
         ▼
Student Model (~100M params)
  - Faster inference
  - Smaller size
  - Task-specific
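
If you train your own student, the transfer is usually driven by a distillation loss: the student learns to match the teacher's temperature-softened probability distribution as well as the ground-truth labels. The sketch below is a minimal, framework-free illustration of that loss in plain JavaScript; teacherLogits, studentLogits, trueLabel, and the T/alpha constants are placeholders you would wire into your own training loop (in practice this step runs in Python, as in the sections below).

// Sketch of the classic soft-label distillation loss: blend cross-entropy
// against the teacher's softened distribution with cross-entropy against
// the true label.
function softmax(logits, temperature = 1) {
  const scaled = logits.map(x => x / temperature);
  const max = Math.max(...scaled);
  const exps = scaled.map(x => Math.exp(x - max));
  const sum = exps.reduce((a, b) => a + b, 0);
  return exps.map(x => x / sum);
}

function distillationLoss(teacherLogits, studentLogits, trueLabel, T = 2, alpha = 0.5) {
  const teacherProbs = softmax(teacherLogits, T); // soft targets
  const studentSoft = softmax(studentLogits, T);
  const studentHard = softmax(studentLogits, 1);

  // Cross-entropy against the teacher's soft distribution,
  // scaled by T^2 to keep gradient magnitudes comparable
  const softLoss = -teacherProbs.reduce(
    (acc, p, i) => acc + p * Math.log(studentSoft[i] + 1e-12), 0) * T * T;

  // Standard cross-entropy against the ground-truth label
  const hardLoss = -Math.log(studentHard[trueLabel] + 1e-12);

  return alpha * softLoss + (1 - alpha) * hardLoss;
}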

Practical Distillation Approaches

1. Use Pre-Distilled Models

Start with models already distilled for browser use:

// Using TensorFlow.js with pre-trained models
import * as tf from '@tensorflow/tfjs';
import * as use from '@tensorflow-models/universal-sentence-encoder';

class SemanticSearch {
  constructor() {
    this.model = null;
    this.initialized = false;
  }

  async initialize() {
    // Load distilled sentence encoder (~20MB)
    this.model = await use.load();
    this.initialized = true;
  }

  async embed(texts) {
    if (!this.initialized) await this.initialize();
    const embeddings = await this.model.embed(texts);
    const values = embeddings.arraySync();
    embeddings.dispose(); // free the tensor's memory after copying values out
    return values;
  }

  async findSimilar(query, documents, topK = 5) {
    const queryEmbed = await this.embed([query]);
    const docEmbeds = await this.embed(documents);

    // Calculate cosine similarity
    const similarities = docEmbeds.map((docEmbed, i) => ({
      index: i,
      score: this.cosineSimilarity(queryEmbed[0], docEmbed)
    }));

    return similarities
      .sort((a, b) => b.score - a.score)
      .slice(0, topK);
  }

  cosineSimilarity(a, b) {
    let dot = 0, normA = 0, normB = 0;
    for (let i = 0; i < a.length; i++) {
      dot += a[i] * b[i];
      normA += a[i] * a[i];
      normB += b[i] * b[i];
    }
    return dot / (Math.sqrt(normA) * Math.sqrt(normB));
  }
}
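
As a usage sketch, a content script could rank the headings on the current page against a user's query. This assumes the SemanticSearch class above is bundled into the content script; how you trigger the search and report results is up to you:

// Example usage: rank page headings against a query
const search = new SemanticSearch();

async function searchHeadings(query) {
  const headings = Array.from(document.querySelectorAll('h1, h2, h3'))
    .map(el => el.textContent.trim())
    .filter(Boolean);

  const matches = await search.findSimilar(query, headings, 3);
  return matches.map(m => ({ text: headings[m.index], score: m.score }));
}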

2. Distill for Specific Tasks

Create task-specific models from larger ones:

# Python script to generate training data using a teacher model
# (uses the openai>=1.0 client interface)
import json
from openai import OpenAI

client = OpenAI()

def generate_training_data(task_examples):
    training_data = []

    for example in task_examples:
        # Use GPT-4 as the teacher
        response = client.chat.completions.create(
            model="gpt-4",
            messages=[
                {"role": "system", "content": example["system"]},
                {"role": "user", "content": example["user"]}
            ],
            temperature=0.7,
            n=5  # Generate multiple responses per prompt
        )

        for choice in response.choices:
            training_data.append({
                "input": example["user"],
                "output": choice.message.content,
                "task": example["task"]
            })

    return training_data

# Example: Distill for sentiment analysis
sentiment_examples = [
    {
        "system": "Classify the sentiment as positive, negative, or neutral.",
        "user": "This product exceeded my expectations!",
        "task": "sentiment"
    },
    # ... more examples
]

data = generate_training_data(sentiment_examples)

# Save for student model training
with open("training_data.json", "w") as f:
    json.dump(data, f)

3. Fine-Tune Small Models

Fine-tune a small model on distilled data:

# Fine-tune DistilBERT on task-specific data
from transformers import (
    DistilBertForSequenceClassification,
    DistilBertTokenizer,
    Trainer,
    TrainingArguments
)
from datasets import load_dataset

# Load pre-trained DistilBERT (66M params vs BERT's 110M)
model = DistilBertForSequenceClassification.from_pretrained(
    "distilbert-base-uncased",
    num_labels=3  # positive, negative, neutral
)
tokenizer = DistilBertTokenizer.from_pretrained("distilbert-base-uncased")

# Load distilled training data
dataset = load_dataset("json", data_files="training_data.json")

# Map the teacher's text outputs to integer class labels
label2id = {"negative": 0, "neutral": 1, "positive": 2}

def tokenize(examples):
    encoded = tokenizer(examples["input"], padding="max_length", truncation=True)
    # Assumes the teacher answered with one of the three label words
    encoded["labels"] = [label2id[o.strip().lower().rstrip(".")] for o in examples["output"]]
    return encoded

tokenized = dataset.map(tokenize, batched=True)

# Train (no eval_dataset is passed below, so evaluation is left disabled)
training_args = TrainingArguments(
    output_dir="./distilled-sentiment",
    num_train_epochs=3,
    per_device_train_batch_size=16,
    save_steps=500
)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=tokenized["train"]
)

trainer.train()

# Export for browser
model.save_pretrained("./distilled-sentiment-export")

Converting Models for Browser

ONNX Export

Convert PyTorch/TensorFlow models to ONNX for browser deployment:

# Export to ONNX
import torch
from transformers import DistilBertForSequenceClassification, DistilBertTokenizer

model = DistilBertForSequenceClassification.from_pretrained("./distilled-sentiment-export")
tokenizer = DistilBertTokenizer.from_pretrained("distilbert-base-uncased")
model.eval()  # disable dropout so the traced graph is deterministic

# Create dummy input
dummy_input = tokenizer("Sample text", return_tensors="pt")

# Export
torch.onnx.export(
    model,
    (dummy_input["input_ids"], dummy_input["attention_mask"]),
    "sentiment.onnx",
    input_names=["input_ids", "attention_mask"],
    output_names=["logits"],
    dynamic_axes={
        "input_ids": {0: "batch", 1: "sequence"},
        "attention_mask": {0: "batch", 1: "sequence"},
        "logits": {0: "batch"}
    },
    opset_version=14
)

TensorFlow.js Conversion

# Install converter
pip install tensorflowjs

# Convert SavedModel to TF.js format
tensorflowjs_converter \
    --input_format=tf_saved_model \
    --output_format=tfjs_graph_model \
    --quantization_bytes=2 \
    ./saved_model \
    ./tfjs_model
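
Loading the converted model inside the extension might look roughly like this. The tfjs_model/ path and the input names are assumptions: they depend on where you copy the converter output in your bundle and on your SavedModel's signature.

import * as tf from '@tensorflow/tfjs';

let modelPromise;

function loadTfjsModel() {
  if (!modelPromise) {
    // model.json references the sharded weight files written by the converter
    modelPromise = tf.loadGraphModel(chrome.runtime.getURL('tfjs_model/model.json'));
  }
  return modelPromise;
}

async function runInference(inputIds, attentionMask) {
  const model = await loadTfjsModel();
  // Input names must match the exported signature; assumes a single output tensor
  const output = model.execute({
    input_ids: tf.tensor2d([inputIds], [1, inputIds.length], 'int32'),
    attention_mask: tf.tensor2d([attentionMask], [1, attentionMask.length], 'int32')
  });
  return output.arraySync();
}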

Running Models in Extensions

ONNX Runtime Web

// background.js
import * as ort from 'onnxruntime-web';

class SentimentAnalyzer {
  constructor() {
    this.session = null;
    this.vocab = null;
  }

  async initialize() {
    // Load ONNX model
    this.session = await ort.InferenceSession.create(
      chrome.runtime.getURL('models/sentiment.onnx')
    );

    // Load tokenizer vocabulary
    const vocab = await fetch(chrome.runtime.getURL('models/vocab.json'));
    this.vocab = await vocab.json();
  }

  tokenize(text) {
    // Simplified whitespace tokenization; production code should reuse the
    // WordPiece tokenizer (and [CLS]/[SEP] tokens) the model was trained with
    const tokens = text.toLowerCase().split(/\s+/);
    const ids = tokens.map(t => this.vocab[t] || this.vocab['[UNK]']);

    // Pad/truncate to max length
    const maxLen = 128;
    const padded = ids.slice(0, maxLen);
    while (padded.length < maxLen) padded.push(0);

    return {
      input_ids: padded,
      attention_mask: padded.map(id => id > 0 ? 1 : 0)
    };
  }

  async analyze(text) {
    const { input_ids, attention_mask } = this.tokenize(text);

    const feeds = {
      input_ids: new ort.Tensor('int64', BigInt64Array.from(input_ids.map(BigInt)), [1, 128]),
      attention_mask: new ort.Tensor('int64', BigInt64Array.from(attention_mask.map(BigInt)), [1, 128])
    };

    const results = await this.session.run(feeds);
    const logits = results.logits.data;

    // Softmax to probabilities
    const probs = this.softmax(Array.from(logits));

    const labels = ['negative', 'neutral', 'positive'];
    const maxIdx = probs.indexOf(Math.max(...probs));

    return {
      label: labels[maxIdx],
      confidence: probs[maxIdx],
      probabilities: Object.fromEntries(labels.map((l, i) => [l, probs[i]]))
    };
  }

  softmax(arr) {
    const max = Math.max(...arr);
    const exp = arr.map(x => Math.exp(x - max));
    const sum = exp.reduce((a, b) => a + b);
    return exp.map(x => x / sum);
  }
}
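
Wiring the analyzer into the extension's messaging layer could look like the sketch below; the ANALYZE_SENTIMENT message type is just an illustrative convention for requests coming from content scripts or the popup.

// background.js (continued): expose the analyzer to the rest of the extension
const analyzer = new SentimentAnalyzer();
const ready = analyzer.initialize();

chrome.runtime.onMessage.addListener((message, sender, sendResponse) => {
  if (message.type === 'ANALYZE_SENTIMENT') {
    ready
      .then(() => analyzer.analyze(message.text))
      .then(sendResponse)
      .catch(err => sendResponse({ error: err.message }));
    return true; // keep the message channel open for the async response
  }
});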
Model Size Guidelines:
  • <5MB: Instant load, no user notice
  • 5-20MB: Show loading indicator
  • 20-50MB: Load on demand, cache aggressively
  • >50MB: Consider hybrid approach

WebGPU Acceleration

For larger models, use WebGPU when available:

async function initializeWithGPU() {
  // Check WebGPU availability
  if (!navigator.gpu) {
    console.log('WebGPU not available, falling back to CPU');
    return initializeCPU();
  }

  const adapter = await navigator.gpu.requestAdapter();
  if (!adapter) {
    console.log('No suitable GPU adapter, falling back to CPU');
    return initializeCPU();
  }

  // Configure ONNX Runtime to prefer WebGPU, with WASM as a fallback
  const session = await ort.InferenceSession.create(
    modelPath,
    {
      executionProviders: ['webgpu', 'wasm'],
      graphOptimizationLevel: 'all'
    }
  );

  return session;
}

Hybrid Architecture

Combine on-device and API models:

class HybridAI {
  constructor() {
    this.localModel = new SentimentAnalyzer();
    this.apiEndpoint = 'https://api.yourservice.com/ai';
  }

  async analyze(text, options = {}) {
    const { preferLocal = true, complexityThreshold = 0.7 } = options;

    // Quick local analysis first
    const localResult = await this.localModel.analyze(text);

    // Use local result if confident
    if (preferLocal && localResult.confidence > complexityThreshold) {
      return { ...localResult, source: 'local' };
    }

    // Fall back to API for complex cases
    try {
      const apiResult = await this.queryAPI(text);
      return { ...apiResult, source: 'api' };
    } catch (error) {
      // Use local result as fallback
      return { ...localResult, source: 'local', fallback: true };
    }
  }

  async queryAPI(text) {
    const response = await fetch(this.apiEndpoint, {
      method: 'POST',
      headers: { 'Content-Type': 'application/json' },
      body: JSON.stringify({ text })
    });
    if (!response.ok) {
      throw new Error(`API request failed with status ${response.status}`);
    }
    return response.json();
  }
}
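
In practice you tune complexityThreshold per task: a higher value routes more traffic to the API. A quick, illustrative call site might look like this (the response shape from the API side is a placeholder):

const ai = new HybridAI();

async function classifySelection(text) {
  // Prefer on-device; escalate low-confidence cases to the API
  const result = await ai.analyze(text, { preferLocal: true, complexityThreshold: 0.8 });
  console.log('sentiment result via', result.source, result);
  return result;
}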

Quantization for Size Reduction

Reduce model size with quantization:

# Quantize ONNX model weights from float32 to 8-bit integers
from onnxruntime.quantization import quantize_dynamic, QuantType

# Load model
model_path = "sentiment.onnx"
quantized_path = "sentiment_quantized.onnx"

# Dynamic quantization (no calibration data needed)
quantize_dynamic(
    model_path,
    quantized_path,
    weight_type=QuantType.QUInt8
)

# Check size reduction
import os
original_size = os.path.getsize(model_path) / (1024 * 1024)
quantized_size = os.path.getsize(quantized_path) / (1024 * 1024)
print(f"Original: {original_size:.1f}MB → Quantized: {quantized_size:.1f}MB")
print(f"Reduction: {(1 - quantized_size/original_size) * 100:.1f}%")

Performance Optimization

Lazy Loading

class LazyModel {
  constructor(modelPath) {
    this.modelPath = modelPath;
    this.model = null;
    this.loading = null;
  }

  async ensureLoaded() {
    if (this.model) return this.model;

    if (!this.loading) {
      this.loading = this.load();
    }

    return this.loading;
  }

  async load() {
    console.log('Loading model...');
    this.model = await ort.InferenceSession.create(this.modelPath);
    console.log('Model loaded');
    return this.model;
  }

  async infer(input) {
    await this.ensureLoaded();
    return this.model.run(input);
  }
}

Caching Inference Results

class CachedInference {
  constructor(model, maxCacheSize = 1000) {
    this.model = model;
    this.cache = new Map();
    this.maxSize = maxCacheSize;
  }

  getCacheKey(input) {
    return JSON.stringify(input);
  }

  async infer(input) {
    const key = this.getCacheKey(input);

    if (this.cache.has(key)) {
      // Re-insert on hit so eviction order reflects recency (true LRU)
      const cached = this.cache.get(key);
      this.cache.delete(key);
      this.cache.set(key, cached);
      return cached;
    }

    const result = await this.model.infer(input);

    // LRU eviction
    if (this.cache.size >= this.maxSize) {
      const firstKey = this.cache.keys().next().value;
      this.cache.delete(firstKey);
    }

    this.cache.set(key, result);
    return result;
  }
}

Summary

Model distillation enables powerful AI features in browser extensions without the latency and privacy concerns of API calls. Start with pre-distilled models, fine-tune for your specific tasks, and use hybrid architectures to balance capability with speed.

Key strategies:

  • Use pre-distilled models (Universal Sentence Encoder, DistilBERT) when possible
  • Create task-specific models for better size/performance
  • Quantize models for 2-4x size reduction
  • Implement hybrid local/API approaches
  • Cache results aggressively
  • Load models lazily on first use