grpo_server: GRPO for fun and agents and profit

TL;DR: I’m releasing a new open-source package grpo_server. Only of interest if you want to play around with teaching language models to do things.

After doing my first experiment with GRPO, it occurred to me that it would be cool to reverse the way the code works.

Right now, the trl GRPO implementation is pretty conventional: it is given a prompt from a dataset, generates completions, and then calls a list of reward functions to evaluate them.

What if this were done a bit differently? Instead of the prompt and the reward living far apart from each other inside the trainer, it would be nice to have a closer connection between them. Thus, the idea of grpo_server was born (the package is under Apache 2.0; not yet on PyPI).

It’s a network (http) server with two operations (+ some management stuff):

* get completions for this prompt

* give feedback (here are some completions, this is how much reward each of them got)
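To make the shape of those two operations concrete, here is a minimal sketch of one round trip from Python. The port, endpoint paths, and JSON field names are my own illustration, not necessarily grpo_server's actual API; check the repository for the real interface.

```python
import requests

BASE = "http://localhost:8000"  # wherever grpo_server is listening; port is illustrative

prompt = "Write a one-sentence summary of GRPO."

# Operation 1: ask the server for a group of completions for this prompt.
# (Endpoint path and JSON field names are assumptions, not the package's confirmed schema.)
out = requests.post(f"{BASE}/completions", json={"prompt": prompt, "n": 8}).json()
completions = out["completions"]

# Operation 2: tell the server how much reward each completion earned.
rewards = [1.0 if len(c) <= 200 else 0.0 for c in completions]  # your own reward logic here
requests.post(
    f"{BASE}/feedback",
    json={"prompt": prompt, "completions": completions, "rewards": rewards},
)
```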

What this allows:

1) you can write your "task to be learned" code in a straightforward way (generate prompt -> get completions -> compute rewards); see the sketch after this list

2) you can run multiple such tasks at the same time on the same underlying neural network

3) you can run the neural network on a fast cloud machine with lots of GPUs and have your local machine run the prompt and reward scripts

4) you can use it with agents -- anything where you get feedback for multiple possible actions of an earlier stage is fair game

5) you can actually also implement DPO-like stuff: the completions need not be generated by the model! (not yet implemented; currently the tokens are passed back and forth, but it's not hard to hack if you take care of the tokenization yourself)
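To illustrate points 1 and 3, here is a hedged sketch of what a complete "task to be learned" script could look like, run from your local machine against a server running elsewhere. The server URL, endpoint names, and payload fields are assumptions for illustration, not the package's confirmed API.

```python
import random
import requests

SERVER = "http://my-gpu-box.example.com:8000"  # hypothetical remote grpo_server instance

def make_prompt() -> tuple[str, int]:
    """Generate a prompt for a toy task: add two small numbers."""
    a, b = random.randint(0, 99), random.randint(0, 99)
    return f"What is {a} + {b}? Answer with just the number.", a + b

def reward(completion: str, answer: int) -> float:
    """1.0 if the completion is exactly the right number, else 0.0."""
    return 1.0 if completion.strip() == str(answer) else 0.0

for step in range(1000):
    prompt, answer = make_prompt()

    # Ask the server for a group of completions (endpoint and fields are assumptions).
    out = requests.post(f"{SERVER}/completions",
                        json={"prompt": prompt, "n": 8}).json()
    completions = out["completions"]

    # Score them locally, then send the feedback back for the GRPO update.
    rewards = [reward(c, answer) for c in completions]
    requests.post(f"{SERVER}/feedback",
                  json={"prompt": prompt,
                        "completions": completions,
                        "rewards": rewards})
```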

There’s an example that shows it’s possible to teach a network (smollm) to respond to an alphabetize-this-list request in a specific format.
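The reward in a task like that is essentially a format check. A sketch of what such a reward function could look like (the exact format enforced by the example in the repo may differ):

```python
def alphabetize_reward(items: list[str], completion: str) -> float:
    """Reward 1.0 if the completion is exactly the items sorted, one per line.
    (The format used by the actual grpo_server example may differ; this is a sketch.)"""
    expected = "\n".join(sorted(items))
    return 1.0 if completion.strip() == expected else 0.0

# e.g. alphabetize_reward(["pear", "apple", "mango"], "apple\nmango\npear") == 1.0
```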

The version I’m releasing now is a VERY rough implementation. Be prepared to get your hands dirty. It’s still missing vital pieces (e.g. peft support; that’s easy to add but I haven’t had time yet, and I’ve only tested this with smollm so far).

But let’s see if someone figures out how to do something cool with it :) If you do, please get in touch.

Patches accepted. YMMV. Caveat emptor.



