Deconstructing Metal Kernels from OpenAI’s GPT-oss

Aug 16, 2025

Learn to write performant kernels in the Metal Shading Language

I wanted to learn the Metal Shading Language (MSL), the language that Apple's Metal framework uses to program GPUs on Apple devices. Compared with CUDA or ROCm, there are few good learning resources, so I've been learning by reading kernels written by engineers at OpenAI, DeepSeek, and other AI labs. This series of articles is aimed at readers who know nothing about GPU programming, with the goal of getting them to write performant kernels for any use case. In this article, we'll deconstruct the RMSNorm kernel written for GPT-oss to get a basic understanding of how these kernels improve performance.

Quick Overview of ML Inference

OpenAI released two models: gpt-oss-20b and gpt-oss-120b. Dense LLMs use every weight for every token, which is expensive in both memory bandwidth and compute (measured in TFLOPS, tera floating-point operations per second). GPT-oss models are Mixture of Experts (MoE) models. Instead of one large feed-forward block per layer, an MoE model has many smaller “expert” feed-forward blocks (128 per layer in the 120B model and 32 in the 20B model) that tend to specialize during training. As tokens pass through, a router sends each token to the experts it scores highest (4 per token in GPT-oss). Thus, rather than all parameters being active, only a subset is used for each token: about 5.1B parameters for GPT-oss-120B and about 3.6B for GPT-oss-20B.

As with all transformer models, GPT-oss has attention blocks that let the model pick which tokens to attend to. Traditional transformers use multi-head attention, where each head computes its own query (Q), key (K), and value (V) vectors and then calculates attention. This is expensive, so researchers came up with multi-query attention, where each head computes a unique Q but all heads share a single K and V. This trades richness of representation for faster inference, which can harm response quality. As a compromise, researchers created grouped-query attention, a middle ground where heads are divided into groups that share K and V but compute unique Q.

RMSNorm

RMSNorm is a simple algorithm that takes an input tensor and a weights tensor and normalizes the activations around each sublayer in a transformer. It divides the input tensor by its root-mean-square value, which brings all the values to a common scale. In practice, we also want the ability to scale individual features up or down, so we multiply the normalized vector elementwise by learned weights.

The original Transformer (2017) applied normalization after each block (post-norm); now, in most LLMs, it is applied before each block (pre-norm). This is because, with post-norm, gradients struggle to propagate back through deeper models, which makes them hard to train without other tricks (learning-rate warmup, gradient clipping, etc.).

Let's break down the code! A kernel is a function whose instructions are executed by every thread launched on the GPU. Metal lets you declare the maximum threadgroup size on the kernel itself with the max_total_threads_per_threadgroup attribute.

[[max_total_threads_per_threadgroup(1024)]]
kernel void gptoss_f32_bf16w_rmsnorm(
    // Arg syntax: <dtype> <name> <[[ buffer(idx) ]]>
    // constant: read-only data in the constant address space, visible to every thread
    constant gptoss_rmsnorm_args& args [[ buffer(0) ]],
    // float4 is a vector of 4 floats of 4 bytes each (16 bytes total)
    const device float4* input [[ buffer(1) ]],
    // bfloat4 is a vector of 4 bfloat16 values (8 bytes total)
    const device bfloat4* weights [[ buffer(2) ]],
    device float4* output [[ buffer(3) ]],
    // Built-in values describing where this thread sits in the grid
    uint gid [[threadgroup_position_in_grid]],
    uint tid [[thread_position_in_threadgroup]],
    uint threadgroup_size [[ threads_per_threadgroup ]])

The body of the kernel starts by offsetting the input and output pointers to the row that this thread's threadgroup is responsible for. Each thread then accumulates the squares of its share of that row.

    // Assumes each SIMD group is 32 threads
    const uint simdgroup_size = 32;
    // Allocates threadgroup (shared) memory big enough for 32 floats, one per SIMD group
    threadgroup float threadgroup_buffer[32];
    // Each threadgroup handles 1 row of contiguous float4 values
    // args.num_vecs is how many float4 elements belong to one row handled by one threadgroup
    // Pointer arithmetic moves both pointers to the start of this row: &input[gid * num_vecs]
    input += gid * args.num_vecs;
    output += gid * args.num_vecs;
    // Declares 4 floats initialized to 0
    float4 sumsq4 = 0.0f;
    // Loop until you've handled all the elements in your group
    // Each iteration jumps ahead by the total number of threads in the threadgroup
    for (uint i = tid; i < args.num_vecs; i += threadgroup_size) {
        const float4 val = input[i];
        // fma: sumsq4 = val * val + sumsq4, computed per component
        sumsq4 = metal::fma(val, val, sumsq4);
    }

Now we start to get into some important GPU optimizations. sumsq4 is a float4, so it holds 4 values. Earlier, when we looped over the row and used fma on the 4 values in each iteration, each component accumulated into its own lane, so we still need to combine them into a single number. We could add up all the values one after another, but this is slow. In GPU programming and distributed training, we prefer smarter reductions such as tree reductions, where you repeatedly merge adjacent values until you’re left with one value. This is what the following code does to get sumsq.

Each thread now holds its own private sumsq, and threads cannot read each other's registers directly. This is where metal::simd_sum(sumsq) comes in: it sums the value across all threads in a SIMD group and returns the total to each of them. Then only the first thread in each SIMD group writes the result to threadgroup memory, which avoids redundant writes. The SIMD group index is computed so that the SIMD groups write to adjacent slots of the buffer for the next step. Finally, threadgroup_barrier makes every thread wait until all of these writes have completed before any thread continues.

    // Tree-reduce sumsq within thread, then all-reduce within threadgroup.
    // [x + z, y + w], where x, y, z, w are the four lanes of sumsq4
    const float2 sumsq2 = sumsq4.xy + sumsq4.zw;
    // x + y + z + w
    float sumsq = sumsq2.x + sumsq2.y;
    // Warning: this all-reduce works only for simdgroup of 32 threads and threadgroup of 32*32=1024 threads.
    // Sums the value across all lanes in the SIMD group; every lane receives the total
    sumsq = metal::simd_sum(sumsq);
    // Only the first lane of each SIMD group writes its partial sum
    if (metal::simd_is_first()) {
        // tid is in (0...1023); dividing by 32 gives the SIMD group index (0...31)
        const uint simdgroup_idx = tid / simdgroup_size;
        // Fills one of the 32 slots in threadgroup_buffer
        threadgroup_buffer[simdgroup_idx] = sumsq;
    }
    // Barrier sync: wait until every SIMD group has written its partial sum
    // to threadgroup memory before any thread reads the buffer
    metal::threadgroup_barrier(metal::mem_flags::mem_threadgroup);

Now, all that is left to do is sum the partial sums from all SIMD groups. We do this with simd_sum again and then divide the total by the number of channels in the input row to get the mean square. We add a small epsilon for numerical stability and take the reciprocal square root to get our scaling factor. Finally, each thread multiplies its share of the input by the scaling factor and then by the corresponding weights.

    // Lane index of this thread within its SIMD group (0...31)
    const uint simdgroup_tid = tid % simdgroup_size;
    // Each lane loads one SIMD group's partial sum (slots 0...31)
    sumsq = threadgroup_buffer[simdgroup_tid];
    // Sum the 32 partial sums with simd_sum; every thread now holds the row total
    sumsq = metal::simd_sum(sumsq);
    // Divide by the number of channels to get the mean square
    const float avgsq = sumsq / args.num_channels;
    // Metal has precise::rsqrt and fast::rsqrt (precise is slower but more accurate)
    // Epsilon is added for numerical stability
    const float scale = metal::precise::rsqrt(avgsq + args.epsilon);
    // Same loop condition as above
    for (uint i = tid; i < args.num_vecs; i += threadgroup_size) {
        // Multiply each value by the normalization factor
        const float4 val = input[i] * scale;
        // Cast weights from bfloat4 to float4 on read
        const float4 weight_val = static_cast<float4>(weights[i]);
        // Elementwise float4 * float4 multiplication
        output[i] = val * weight_val;
    }
