Writing
A ChatGPT clone, in 3000 bytes of C, backed by GPT-2

by Nicholas Carlini 2023-04-02



This program is a dependency-free implementation of GPT-2. It loads the weight matrix and BPE file out of the original TensorFlow files, tokenizes the input with a simple byte-pair encoder, implements a basic linear algebra package with matrix math operations, defines the transformer architecture, performs transformer inference, and un-tokenizes the output with the BPE decoder. All in ~3000 bytes of C.

It's optimized efficiently enough so that GPT-2 Small takes a few seconds per reply on any modern machine. To do this I've implemented KV caching and an efficient matrix multiplication algorithm, with optional OMP parallelism.

You can then use this to create something like ChatGPT---just so long as you don't care about the quality of the output. (It's actually pretty terrible output, objectively speaking... But it does run.) There are a few quirks (especially with handling UTF-8 characters), and running the XL size model at long context length can require ~100GB of RAM. But if you're just typing with ASCII using GPT2-Small it should run just about anywhere.

I've uploaded the code to GitHub, so feel free to try and use it there.

This program is made up of the following main blocks (hover over each to see the corresponding code): Basic matrix math library (700 bytes) Fast matrix multiplication (300 bytes) Neural network layers (300 bytes) Transformer model (600 bytes) Byte pair encoding (400 bytes) I/O (200 bytes) Weight loading (300 bytes) Byte pair encoding loading (300 bytes)

#include<stdio.h>

#include<stdlib.h>

#include<string.h>

#include<math.h>

/* Matrix: row-major floats at i, with j rows and k columns. */
typedef struct{
  float*i;
  int j,k;
} A;

int U;        /* model width (K heads x 64) */
int C;        /* number of transformer layers */
int K;        /* number of attention heads */
int c,d;      /* g() computes output rows c..d-1 */
int S;        /* scratch: last allocation size, last string length, ... */
int zz;       /* context length in tokens */
char*bpe;     /* vocabulary strings, 999 bytes per token */
void*E;       /* bump-allocation cursor used by o() */
void*n;       /* saved cursor that every forward pass rewinds to */
A*f;          /* weight tensors of the layer being evaluated */
FILE*fp;      /* file currently being read */

/* N(i,j): loop a fresh int i over 0..j-1.  j is parenthesized so that an
   expression bound such as N(x,m&3) compares against the whole expression. */
#define N(i,j) for(int i=0; i<(j); i++)

/*
 * o: carve a j x k float matrix out of the bump arena at E and advance E.
 * Sets the global S to the allocation size in bytes (4*j*k).  The storage is
 * zero-filled when i is 1 and left untouched when i is 0 (i scales the memset
 * length).  Nothing is freed; memory is reused only by resetting E.
 */
A o(int j,int k,int i){
  float*a=E;
  S=4*j*k;
  E=(char*)E+S;   /* ISO C has no arithmetic on void*, so step through char* */
  memset(a,0,S*i);
  A R={a,j,k};
  return R;
}

/* I(R,B): define `A R(A a, float k)`, which overwrites every element of a in
   place with the expression B and returns a (same storage).  Inside B:
   b = the element's old value, i = its flat row-major index,
   k = the scalar argument, a = the matrix itself. */
#define I(R,B)A R(A a,float k){ N(i,a.j*a.k){ float b=a.i[i]; a.i[i]=B; } return a; }

/* l: b/k   q: b+k   u: 1/sqrt(b)   z: exp(b)   (u and z ignore k)
   r: copy the first element of each row (rows of a.k elements) across the
      whole row; the row's first element is visited first, so it is unchanged
      when the rest of the row reads it.
   P: causal mask with k = row width: elements whose column i%k exceeds their
      row i/k become 0, the rest become exp(b/8).  i/k is a float division
      here, but for integer row/column it gives the same comparison.  The /8
      is presumably 1/sqrt(64) for the 64-wide heads used in main -- confirm.
   Q: GELU, tanh approximation (.7978845 ~= sqrt(2/pi)). */
I(l,b/k)I(q,b+k)I(u,1./sqrt(b))I(z,exp(b))I(r,a.i[(i/a.k)*a.k])I(P,(i/k<i%(int)k)?0:exp(b/8))I(Q,b/2*(1+tanh(.7978845*(b+.044715*b*b*b))))

/* F(R,B): define `A R(A a, A b)`, which updates a in place with
   a.i[i] = a.i[i] B b.i[i] for every flat index i and returns a. */
#define F(R,B)A R(A a,A b){ N(i,a.j*a.k){ a.i[i]=a.i[i]B b.i[i]; } return a; }

/* V: a+b   v: a*b   H: a/b   (elementwise, same shape)
   at / mt: add / multiply b.i[i%a.k], i.e. b is a row vector tiled over every
      row of a.  Their B text ends in ';', so the macro's trailing "b.i[i]"
      becomes a separate expression statement whose value is discarded.
      NOTE(review): that statement still names b.i[i] for i up to a.j*a.k,
      past the end of a 1 x a.k vector; compilers normally drop the unused
      read -- confirm this is acceptable.
   X: new matrix in which every element of a row holds that row's sum of a
      (sums accumulate in column 0, then r() copies them across the row).
   p: new matrix holding the transpose of a. */
F(V,+)F(v,*)F(H,/)F(at,+b.i[i%a.k];)F(mt,*b.i[i%a.k];)A X(A a){A R=o(a.j,a.k,1);N(i,a.j*a.k)R.i[(i/a.k)*a.k]+=a.i[i];r(R,0);return R;}A p(A a){A R=o(a.k,a.j,1);N(i,a.j*a.k)R.i[i%a.k*a.j+i/a.k]=a.i[i];return R;}

/*
 * g: matrix product a * b^T, i.e. R[i][j] = sum_k a[i][k] * b[j][k].
 * Only output rows c..d-1 (globals) are computed.  R comes from the arena and
 * is zeroed only when c == 0; when c > 0 the earlier rows still hold what a
 * previous pass computed at the same arena address (main rewinds the arena
 * before every pass), which is how earlier tokens' results are reused.
 * The 4x4 tiling has no remainder loop: b.j and a.k must be multiples of 4.
 * Returns a fresh copy of R, so in-place ops by the caller leave R intact.
 */
A g(A a,A b){
  A R=o(a.j,b.j,!c);
  for(int i=c;i<d;i++)
    for(int j=0;j<b.j;j+=4)
      for(int k=0;k<a.k;k+=4)
        N(k2,4)N(j2,4)R.i[i*b.j+j+j2]+=a.i[i*a.k+k+k2]*b.i[(j+j2)*b.k+k+k2];
  return V(o(R.j,R.k,1),R);
}

/* J: view (no copy) of a's storage starting b*j floats in, shaped j x k. */
A J(A a,int b,int j,int k){A R={a.i+b*j,j,k};return R;}

/*
 * s: layer norm of every row of a, then per-column gain f[i+1] and bias f[i]:
 *   (a - mean) / sqrt(var + 1e-5) * f[i+1] + f[i]
 * var is the mean of squared deviations, dividing by the row width b.k (the
 * biased estimator that layer normalization is defined with).
 * Note: the centering step writes into a's storage, so the caller's matrix is
 * left mean-centered.  This appears harmless because main only ever consumes
 * the residual stream through another s(), which removes any per-row offset.
 */
A s(A a,int i){
  A b=V(a,l(X(a),-a.k));                     /* b = a - row mean (in place) */
  A k=l(X(v(V(o(b.j,b.k,1),b),b)),b.k);      /* k = row mean of b*b */
  A R=at(mt(v(V(o(b.j,b.k,1),b),u(q(k,1e-5),0)),f[i+1]),f[i]);
  return R;
}

/* G(a,i): dense layer with the current layer's tensors: g() against weight
   matrix f[i+1], then add bias f[i] to every row.  Arguments are
   parenthesized so an index expression such as G(x,c?2:8) expands as intended. */
#define G(a,i) at(g((a),f[(i)+1]),f[(i)])

/*
 * m: read the next j x k block of 4-byte floats from the weights file fp
 * (a zero dimension is treated as 1) into a new arena matrix and return its
 * k x j transpose.  Exits with a message if the file ends early instead of
 * silently using whatever the arena held.
 */
A m(int j,int k){
  j+=!j;
  k+=!k;
  A a=o(j,k,1);   /* o() also leaves the block's size in bytes in S */
  if(fread(a.i,S,1,fp)!=1){
    fputs("weights file ended early\n",stderr);
    exit(1);
  }
  return p(a);
}

/* First token chosen by the most recent top-level Y() call; -1 when the
   string could not be tokenized. */
int t;

/*
 * Y: cheapest tokenization of the NUL-terminated string R, found by trying
 * every vocabulary entry that is a prefix of R and recursing on the rest
 * (exponential time in the worst case).  Each token costs 1e7 plus its index,
 * so fewer tokens win and ties go to lower indices.  Returns the cost (0 for
 * "", 1e9 when nothing cheaper was found) and stores the first token of the
 * best split in t, or -1 if there was none.  Clobbers the global S.
 */
int Y(char*R){
  if(!*R)return 0;
  int B=1e9,r=-1;
  N(i,5e4){
    if(bpe[999*i]&&strncmp(bpe+999*i,R,S=strlen(bpe+999*i))==0){
      int k=Y(R+S)+i+1e7;
      if(k<B){B=k;r=i;}
    }
  }
  t=r;
  return B;
}

/*
 * w: tokenize q and append the token ids at B; returns the new end pointer.
 * q is cut into chunks: each chunk is one character followed by any run of
 * digits or characters comparing greater than '@' (64).  Each chunk is
 * tokenized greedily by repeated Y() calls.  Positions no token covers
 * (t < 0) are skipped one byte at a time instead of indexing with garbage.
 * Chunks longer than 999 bytes are truncated.
 */
int *w(char*q,int*B){
  char R[1000];
  int i=0;
  while(q[i]){
    int j=i++;
    while(47<q[i]&&q[i]<58||64<q[i])i++;
    /* copy only this chunk (bounded), not the whole rest of q */
    int L=i-j<999?i-j:999;
    memcpy(R,q+j,L);
    R[L]=0;
    fflush(stdout);   /* show any pending prompt text before slow tokenizing */
    int k=0;
    while(R[k]){
      Y(R+k);
      if(t<0){k++;continue;}
      *B++=t;
      k+=strlen(bpe+t*999);
    }
  }
  return B;
}

int main(int S,char**D){S=D[1][5]+3*D[1][7]+3&3;K=12+4*S+(S>2);U=K*64;C=12*S+12;zz=atoi(D[4]);E=malloc(2LL*U*U*C*zz);

bpe=malloc(1e9);fp=fopen(D[2],"r");unsigned char a[S=999],b[S];N(i,5e4){int k=i*S;if(i<93){bpe[k]=i+33;bpe[k+1]=0;} else if(i>254){fscanf(fp,"%s %s",a,b);strcat((char*)a,(char*)b);int j=0;N(i,a[i])bpe[k+j++]=a[i]^196?a[i]:a[++i]-128;bpe[k+j++]=0;} else if(i>187){bpe[k]=i-188;bpe[k+1]=0;}}int e[1024];d=w(D[3],e)-e;int h;N(i,d){if(e[i]==18861)h=i+1;}printf("AI");N(i,d-h)printf("%s",bpe+e[i+h]*999);

fp=fopen(D[1],"r");A\

 x[999];A*R=x;N(i,C){N(j,12)*R++=m(U+U*(j?j^8?j^11?0:3:3:2),U*((j%8==3)+3*(j%8==1)+(j==9)));}*R++=m(U,1);*R++=m(U,1);A QA=m(1024,U),Z=p(m(5e4,U));

while(1){char W[1000]={ 0} ;int T;strcat(W,"\nAlice: ");printf("\n%s: ",bpe+20490*999);fflush(stdout);fgets(W+8,1000,stdin);printf("AI:");strcat(W,"\nBob:");d=w(W,e+d)-e;n=E;c=0;

while(1){E=n;T=d+32-d%32;c*=!!(d%32);A O=o(T,U,1);N(i,d){N(j,U)O.i[i*U+j]=Z.i[e[i]*U+j]+QA.i[j*1024+i];}N(i,C){int y;S=0;N(j,10){if(j==i)y=S;S++;N(k,10*(j>0)){if(j*10+k<C&&S++&&i==j*10+k)y=S;}}f=x+12*y;A QB=p(J(G(s(O,4),0),0,T*3,U));A B=o(U,T,1);N(k,K){A L=p(J(QB,k*3,64*T,3)),a=P(g(p(J(L,0,64,T)),p(J(L,T,64,T))),T),R=p(g(H(a,X(a)),J(L,T*2,64,T)));memcpy(B.i+64*T*k,R.i,64*T*4);}O=V(O,G(p(B),2));O=V(O,G(Q(G(s(O,6),8),0),10));}f=x;O=s(O,12*C);c=0;int S=d;d=1;A B=g(p(J(O,S-1,U,1)),Z);c=d=S;S=0;N(i,5e4){if(B.i[i]>B.i[S])S=i;}if(d==zz){memcpy(e,e+zz/2,S*2);d-=zz/2;c=0;}e[d++]=S;

if(bpe[S*999]==10)break;printf("%s",bpe+S*999);fflush(stdout);}}}


Background: ChatGPT and transformers

In case you've been living under a rock for the past few months, ChatGPT is an application where you can talk to a type of machine learning model called a "language model" as if it was another person. It responds remarkably well, and GPT-4, the latest model that powers ChatGPT, is incredibly impressive.

This C program implements the behavior of ChatGPT using a much weaker model from 2019: GPT-2. Despite its version number being just 2 smaller than GPT-4's, it has nowhere near the same capabilities---but it is open source. So it has that going for it.

GPT-2 is a type of machine learning model called a "transformer". These neural networks take a fixed-size sequence of words as input, and predict the next word that will occur. By repeating the procedure over and over, you can use them to generate arbitrary-length sequences.

This post isn't meant to be an introduction to all the machine learning you'll need to know why a transformer is designed the way it is, but the rest of this post will be dedicated to describing how the above C code works.


Walkthrough the C Code

Getting started: Matrix Math (700 bytes)

Neural networks are just matrix operations, so we're going to need to start by building a matrix library in as few bytes as possible.

My definition of a matrix is completely minimal:

/* A row-major matrix: element (r, c) is stored at dat[r*cols + c]. */
typedef struct {
  float *dat;   /* rows*cols floats, stored contiguously */
  int rows;
  int cols;
} Matrix;

We'll begin by observing that while there are a bunch of different operations we'll need to implement, there are basically two "types" of operations:

  1. Matrix-constant operations (e.g., add 7 to each entry of a matrix)
  2. Matrix-matrix operations (e.g., add corresponding matrix entries)

This similarity allows us to use macros to pull out a bunch of the common logic into a meta-routine that knows how to operate on, for example, pairs of matrices and just leaves the specific operator implementation defined.

To do this in C, I'll define the macro

#define BINARY(FUNCTION, OPERATION)

as the following:

Matrix FUNCTION(Matrix a, Matrix b) {
  // Combine a and b element by element with OPERATION, writing the result
  // into a's storage (a is a struct holding a pointer, so this is in place).
  for (int i = 0; i < a.rows; i++) {
    for (int j = 0; j < a.cols; j++) {
      a.dat[i*a.cols + j] = a.dat[i*a.cols + j] OPERATION b.dat[i*a.cols+j];
    }
  }
  return a;
}

And so for example this lets us just write

// Each BINARY(...) expands to a complete function definition, so no
// trailing ';' (ISO C does not allow an empty declaration at file scope).
BINARY(matrix_elementwise_add, +)

BINARY(matrix_elementwise_multiply, *)

and have it automatically expand to the full functions that perform elementwise addition or multiplication of two matrices. I define a few other easy-to-understand operations (such as elementwise division) the same way.

Now the thing about C's #defines is that they're basically just glorified regexes. So when we actually run this, what's going to happen is we're going to take the line

a[i*a.cols + j] = a[i*a.cols + j] OPERATION b[i*a.cols+j];  // a[...], b[...] stand for a.dat[...], b.dat[...]

and, for the case of multiplication, expand it to

a[i*a.cols + j] = a[i*a.cols + j] * b[i*a.cols+j];  // OPERATION replaced by *

But this replacement is almost literally just a regular expression replace. We could have put anything in place of OPERATION. This allows us to define a function like

// Adds b.dat[i%a.cols] to each element; i%a.cols is the column when i is the flat element index, as in the golfed F macro.
BINARY(add_tile, + b.dat[i%a.cols] ; )

Which at first glance looks rather confusing---what is that semi-colon doing there?---but if you just do a regular expression replace on it, you'll see that it expands to

  a[i*a.cols + j] = a[i*a.cols + j] + b.dat[i%a.cols] ; b[i*a.cols+j];  // second statement is the template's leftover b[...]

and because the second statement doesn't do anything, this is just equivalent to

a[i*a.cols + j] = a[i*a.cols + j] + b.dat[i%a.cols];

(TAKE THAT LANGUAGES WITH PROPER MACROS. LISP ISN'T ALWAYS BETTER THAN C!)

Fast matrix multiplication (300 bytes)

The basic implementation of matrix multiplication is entirely straightforward: we just implement the naive cubic-time three loops: (There's nothing intelligent about my matrix multiplication. If you know how to make matrix multiplication fast you can just move along.)

// Naive product a * b^T: out[i][j] = sum_k a[i][k] * b[j][k], i.e. each
// output entry is the dot product of a row of a with a row of b.
Matrix matmul(Matrix a, Matrix b) {
  // out is accumulated with += below, so it must start at zero (third
  // argument as in the other NewMatrix(rows, cols, 1) calls in this post).
  Matrix out = NewMatrix(a.rows, b.rows, 1);
  for (int i = 0; i < a.rows; i++)
    for (int j = 0; j < b.rows; j++)
      for (int k = 0; k < a.cols; k++)
        out.dat[i * b.rows + j] += a.dat[i * a.cols + k] * b.dat[j * b.cols + k];
  return out;
}

Fortunately we can make it much faster with just a few bits of intelligence. Because of the way memory and caches work on most computers, it's (much!) faster to read and write to the same piece of memory over and over.

// Same product as matmul (a * b^T), but the j and k loops advance in 4x4
// tiles so the same few rows of b and entries of out are reused while they
// are still in cache. The inner tile loops stop at the matrix edge, so the
// dimensions do not have to be multiples of 4.
Matrix matmul_t_fast(Matrix a, Matrix b) {
  Matrix out = NewMatrix(a.rows, b.rows, 1);  // accumulated with +=, must start at zero
  for (int i = 0; i < a.rows; i++)
    for (int j = 0; j < b.rows; j += 4)
      for (int k = 0; k < a.cols; k += 4)
        for (int k2 = 0; k2 < 4 && k + k2 < a.cols; k2 += 1)
          for (int j2 = 0; j2 < 4 && j + j2 < b.rows; j2 += 1)
            out.dat[i * b.rows + j+j2] += a.dat[i * a.cols + k+k2] * b.dat[(j+j2) * b.cols + k+k2];
  return out;
}

Later we're going to make one more change to the way we do inference and add a new parameter to the matrix multiply that instead allows us to only multiply part of Matrix A by Matrix B, which is useful when we've already pre-computed part of the product.

Neural network layers (300 bytes)

In order to write a transformer I'll need to define a few special neural-network specific layers. One of these is the GELU activation function, which you can just think of as magic.

// GELU activation, tanh approximation (.7978845 is sqrt(2/pi)), applied to every element b.
UNARY(GELU, b / 2 * (1 + tanh(.7978845 * (b + .044715 * b * b * b))))

I also implement a function that zeroes the part of a matrix above the diagonal (and exponentiates the remaining values). This is useful for what's called causal attention: we only want to attend to the past, not the future, and so with this function we set to zero the entries above the diagonal of the attention matrix, where a position would be looking at a later position.

// Causal mask with k = row length: entries whose column (i%k) exceeds their row (i/k) -- future positions -- become 0; the rest become exp(b/8).
UNARY(tril, (i/k<i%(int)k) ? 0 : exp(b/8))

And finally we need a layer normalization function. (Again another piece of magic that you can look up if you want. Basically, it normalizes each row---one token's activations---to zero mean and unit variance, and then applies a learned scale and shift.)

// Layer normalization of every row, followed by the learned per-column
// gain (layer_weights[i+1]) and bias (layer_weights[i]).
Matrix LayerNorm(Matrix a, int i) {
  // Subtract each row's mean. add() writes into a's storage, so the
  // caller's matrix is left mean-centered.
  Matrix b = add(a, divide_const(sum(a), -a.cols));
  // Variance: the mean of the squared deviations, i.e. divide by the row
  // length b.cols (the biased estimator layer normalization is defined with).
  Matrix k = divide_const(sum(multiply(
    add(NewMatrix(b.rows,b.cols,1),b), b)), b.cols);
  // (x - mean) / sqrt(var + 1e-5), then scale and shift.
  Matrix out = add_tile(multiply_tile(
    multiply(add(NewMatrix(b.rows,b.cols,1),b),
    mat_isqrt(add_const(k, 1e-5),0)), layer_weights[i+1]),
                        layer_weights[i]);
  return out;
}

The final piece of the model is the Linear function that just performs a matrix multiplication and adds (with tiling) a bias.

// Dense layer: multiply by weight matrix layer_weights[i+1], then add bias layer_weights[i] to every row; arguments are parenthesized so Linear(x, c ? 2 : 8) expands as intended.
#define Linear(a, i) add_tile(matmul_t_fast((a), layer_weights[(i)+1]), layer_weights[(i)])

Transformer architecture (600 bytes)

With all of this out of the way, we can finally implement our transformer in just 600 bytes.

for (int i = 0; i < NLAYER; i++) {
  // Weights of layer i: permute is its slot in the lexicographically ordered
  // weight file (computed in the weight-loading section below).
  layer_weights = weights + 12*permute;

  // Compute the keys, queries, and values all at once with a big multiply.
  // (LayerNorm also mean-centers `line` in place; this appears harmless
  // because `line` is only ever consumed through another LayerNorm.)
  Matrix qkv = transpose(slice(Linear(LayerNorm(line, 4), 0), 0, T*3, DIM));

  // Make space for the output of the computation (DIM x T)
  Matrix result = NewMatrix(DIM, T, 1);

  for (int k = 0; k < NHEAD; k++) {
    // Split the qkv into each of the heads (64 dimensions per head)
    Matrix merge = transpose(slice(qkv, k*3, 64*T, 3)),
      // perform the product of the queries and keys, then mask out future
      // positions and exponentiate (tril)
      a = tril(matmul_t_fast(transpose(slice(merge, 0, 64, T)),
                             transpose(slice(merge, T, 64, T))), T),
      // finally multiply the softmax output (a/sum(a)) with the values matrix
      out = transpose(matmul_t_fast(divide(a, sum(a)), slice(merge, T*2, 64, T)));
    // and copy the output (64*T floats of 4 bytes) to this head's place in result
    memcpy(result.dat+64*T*k, out.dat, 64*T*4);
  }

  // Attention output projection (tensors 2/3) plus residual connection
  line = add(line,Linear(transpose(result), 2));

  // Feed-forward block: LayerNorm, Linear, GELU activation, Linear,
  // plus residual connection
  line = add(line, Linear(GELU(Linear(LayerNorm(line, 6), 8), 0), 10));
}

// Reset layer weights so we can do the last layer norm (its gain and bias
// come right after the 12*NLAYER per-layer tensors)
layer_weights = weights;
line = LayerNorm(line, 12*NLAYER);

// Scores over the vocabulary for the last position only; tmp is presumably
// the number of tokens so far and wte the token-embedding matrix.
Matrix result = matmul_t_fast(transpose(slice(line, tmp-1, DIM, 1)), wte);


Now here's a fact that might not be completely obvious about transformer inference: once you've called the model to generate one token, you don't actually have to re-compute the entire function to generate the next token. In fact, you only need to do a small amount of additional work to generate each additional token.

This is because once you've computed the output of the transformer on the output of all the tokens up to the Nth token, you can re-use almost all of this output to compute the N+1st token (with a little bit more work.)

To actually implement this, I make all allocations in my code occur sequentially within the same block of memory, to guarantee that each matrix multiply will always use exactly the same memory. Then, at each iteration of the loop, I can just not zero-out the memory before using it for the next iteration, and the memory will already contain the results of the previous iteration. I just need to run the computation for the N+1st row.

Byte pair encoding (400 bytes)

The simplest way to build a language model is on a sequence of words. But because the total number of words is essentially unbounded, and language models need to have fixed-size inputs, it would be necessary to replace sufficiently rare words with a special [OUT OF DISTRIBUTION] token. This is no good.

While a simple “fix” for this would be to just use character-level language models that only know about individual letters, this would be a problem because it would mean that the model would have to learn the meaning of every word from scratch, and also reduces the effective context size of the language model by a factor of the average word length.

So to fix this, language models like GPT-2 work by creating tokens out of "word pieces". Some words might be tokens all by themselves, but rare words are broken up into smaller pieces. For example, the word “nicholas” might be broken up into “nich” “o” “las”.

The general algorithm for this is rather easy to implement: given a word we want to tokenize, we first split it into individual characters. Then, we look for pairs of adjacent tokens that should be merged, and merge them together. We repeat this until there are no more possible merges.

This algorithm is simple but unfortunately rather hard to implement in C because it requires a bunch of allocations, and requires keeping track of a tree-like structure of the tokens.

So instead, we'll turn the rather simple linear time algorithm into a potentially exponential time algorithm but save a bunch of code. Our basic idea will work like this in C-like pseudocode:

// Returns (cost, first token) of the cheapest tokenization of word.
// Each token costs 1e7 plus its vocabulary index, so fewer tokens always
// wins and ties are broken toward lower indices.
word_tokenize(word) {
  if len(word) == 0 { return (0, 0); }
  result = (1e9, -1);   // "no tokenization found yet"
  for (int i = 0; i < VOCAB_LEN; i++) {
    if (is_prefix(bpe[i], word)) {
      sub_cost = word_tokenize(word+len(bpe[i]))[0] + i + 1e7;
      result = min(result, (sub_cost, i));
    }
  }
  return result;
}

That is, to tokenize a word, we try each possible word in the vocabulary to see if it's a prefix of the current word. If so, we try to use this as the first token, and then recursively try to tokenize the rest of the word. We keep track of the best tokenization we've seen so far (as judged by the length, breaking ties by the index of the token in the vocab), and return that.

Weight loading (300 bytes)

We're almost done! The last thing we need to do is load the actual weights of the neural network off disk. This is actually pretty easy, because the weights are stored in a simple binary format that's easy to read in C: it's just a completely flat serialization of 32-bit floats.

The only thing we need to know is how big the various matrices are. And fortunately, this is also easy to figure out. Each of the GPT-2 model sizes has the same architecture, and the weights are saved in the same order, so all we need to do is read the correctly-shaped matrices off of disk.

There's one final annoying thing. The layers of the neural network are not stored on disk in the order you might expect, with layer 0 first, then layer 1, then layer 2. Instead, the first layer is layer 0, then layer 1, and then layer .... TEN! (and then layer 11, and then layer 12.) This is because weights are stored when sorted lexicographically. And lexicographically, “0” comes before “1”, but “10” comes before “2”. So we have to do a bit of work to permute the weights into the correct order with the following code

// Find the slot of layer i in the weight file, whose layers are stored in
// lexicographic order: 0, 1, 10, 11, ..., 19, 2, 20, ... (only layers below
// NLAYER exist). tmp counts the slots in that order.
int permute = 0;
tmp = 0;
for (int j = 0; j < 10; j++) {
  // the single-digit layer j ...
  if (j == i) {
    permute = tmp;
  }
  tmp++;
  // ... is followed by the two-digit layers j0..j9 (for j > 0)
  for (int k = 0; k < 10*(j>0); k++) {
    if (j*10+k < NLAYER) {
      // layer j*10+k occupies slot tmp, so record it before counting it
      if (i == j*10+k) {
        permute = tmp;
      }
      tmp++;
    }
  }
}

Byte pair encoding loading (300 bytes)

In order to actually perform byte-pair encoding, we need to first load the byte-pair encoding vocabulary off disk. In an ideal world we'd actually just have a list of all the words in the vocabulary stored in some reasonable C-readable format, but because the original file was (a) meant for reading in Python, and (b) not meant to make it easy to parse in as few bytes as possible, we'll have to do some work here.

You might expect the file format to just be a list of words one after the other, but it's actually instead a list of the byte-pair encodings. What this means is instead of being able to read “Hello” as one token, the line is “H” “ello” which means we should be merging the tokens “H” and “ello” into a single token “Hello”.

The other challenge is that the file is encoded in something like UTF-8 (but not quite exactly that) for ... reasons. All of the printable ASCII characters are encoded as themselves, but the non-printable characters from 0-32 (including space) are encoded as the value 188+the character. So for example, a space is encoded as the token “Ġ”. But now the problem is that the UTF-8 encoding of “Ġ” is 0xc4 0xa0 when on disk, and so when reading it we have to do some ugly work to convert this back to a space.

And while none of this is actually that hard to do, it still requires a fair amount of code which is annoying when you're trying to compress everything to be small.

unsigned char a[tmp=999],b[tmp];

for (int i = 0; i < 5e4; i++) {
  int k = i*tmp;   // each token's string gets a tmp (999) byte slot in bpe
  if (i < 94) {
    // The first 94 tokens (0..93) are the printable ASCII characters
    // '!' (33) through '~' (126)
    bpe[k] = i + 33;
    bpe[k+1] = 0;
  } else if (i > 254) {
    // Ones above 254 are from the BPE file: each line is a pair of strings
    // to merge, so load both and concatenate them
    fscanf(fp, "%s %s", (char*)a, (char*)b);
    strcat((char*)a, (char*)b);
    int j = 0;
    for (int i = 0; a[i]; i++) {
      // UTF-8 makes life hard, so handle that here: the two-byte sequence
      // 0xC4 xx (e.g. "Ġ" = 0xC4 0xA0) stands for the single byte xx-128
      bpe[k+j++] = a[i] ^ 196 ? a[i] : a[++i]-128;
    }
    bpe[k+j++] = 0;
  } else if (i > 187) {
    // Tokens 188..254 hold the single bytes 0..66; the first 33 of these
    // are the nonprintable ASCII characters 0-32 (space included)
    bpe[k] = i-188;
    bpe[k+1] = 0;
  }
  // NOTE(review): tokens 94..187 are never filled in by this loop.
}

Conclusion

It's really remarkable how you can distill so many decades of progress in machine learning to just a few thousand bytes. There is essentially nothing missing here from everything you need to run any state-of-the-art neural network (except for the actual model weights). While I mostly put this together for fun, it's a nice demonstration of how simple neural networks actually are.




If you want to be notified the next time I write something (maybe like this, maybe not, who knows) enter your email address here.
There's also an RSS Feed if that's your thing.