Using GRPOTrainer with a custom PyTorch module?

Hello! I was wondering what changes I would need to make to use GRPOTrainer with a custom PyTorch module class.

Currently I have an nn.Module subclass that wraps an existing Hugging Face transformer, but with custom forward and generate functions.

I was wondering, though, whether there are resources on either converting an nn.Module into a transformers model that can be used with Trainer, or on what other functionality I would need to implement (and what changes I'd need to make to my forward and generate methods) to work with GRPOTrainer.

If you inherit from PreTrainedModel, you should get most of the necessary functions. As for Trainer, it seems you can customize the loss computation, the gradient/training step, and so on by subclassing it and overriding methods such as compute_loss and training_step.
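Following the PreTrainedModel suggestion, here is a minimal sketch of such a wrapper. It assumes a GPT-2-style causal LM as the inner model (GPT2Config and GPT2LMHeadModel are stand-ins for whatever model you actually wrap), and the class name WrappedLM is hypothetical. It only covers forward(); generate() is not shown, and depending on your transformers version you may also need to inherit GenerationMixin explicitly to get generate().

```python
import torch
from transformers import GPT2Config, GPT2LMHeadModel, PreTrainedModel


class WrappedLM(PreTrainedModel):
    """Hypothetical wrapper: custom forward logic around an existing HF causal LM."""

    # Reuse the inner model's config class so config save/load works.
    config_class = GPT2Config

    def __init__(self, config):
        super().__init__(config)
        # The wrapped Hugging Face model (GPT-2 here is only an example).
        self.inner = GPT2LMHeadModel(config)

    def get_input_embeddings(self):
        # Delegate to the inner model's token embedding layer.
        return self.inner.get_input_embeddings()

    def forward(self, input_ids=None, attention_mask=None, inputs_embeds=None, **kwargs):
        # Custom logic would go here; this sketch just builds inputs_embeds
        # from input_ids when they are not supplied, then calls the inner model.
        if inputs_embeds is None:
            inputs_embeds = self.get_input_embeddings()(input_ids)
        return self.inner(inputs_embeds=inputs_embeds, attention_mask=attention_mask, **kwargs)


if __name__ == "__main__":
    # Tiny randomly initialized config, just to exercise the wrapper.
    cfg = GPT2Config(vocab_size=50, n_positions=32, n_embd=16, n_layer=1, n_head=2)
    model = WrappedLM(cfg)
    ids = torch.randint(0, cfg.vocab_size, (2, 5))
    out = model(input_ids=ids, attention_mask=torch.ones_like(ids))
    print(out.logits.shape)
```

Because the wrapper is a PreTrainedModel, it picks up utilities such as save_pretrained/from_pretrained; whether GRPOTrainer needs anything beyond that depends on your TRL version.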

Sorry for bumping, but do you know if Trainer supports passing inputs_embeds to generate yet? I ended up just monkey-patching the generate method to use my own generation code, but when I pass inputs_embeds instead of input_ids directly into the original generate method, I get the following error:

Traceback (most recent call last):
  File "test_grpo.py", line 124, in <module>
    trainer.train()
  File "env/lib64/python3.9/site-packages/transformers/trainer.py", line 2241, in train
    return inner_training_loop(
  File "env/lib64/python3.9/site-packages/transformers/trainer.py", line 2548, in _inner_training_loop
    tr_loss_step = self.training_step(model, inputs, num_items_in_batch)
  File "env/lib64/python3.9/site-packages/transformers/trainer.py", line 3692, in training_step
    inputs = self._prepare_inputs(inputs)
  File "env/lib64/python3.9/site-packages/trl/trainer/grpo_trainer.py", line 576, in _prepare_inputs
    eos_idx[is_eos.any(dim=1)] = is_eos.int().argmax(dim=1)[is_eos.any(dim=1)]
IndexError: argmax(): Expected reduction dim 1 to have non-zero size.

Fixed. Turns out I just needed to pad the output back with the original prompt (i.e. prepend the prompt tokens to the generated tokens) lol
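A minimal sketch, in plain PyTorch, of why the error happens and what the fix does. It assumes (based on the traceback) that GRPOTrainer treats everything after the first prompt_length tokens of the generate() output as the completion and then runs the EOS-masking line shown in the error; the helper names completions_from_generate, eos_index and pad_back_prompt are hypothetical. If generate returns only the newly generated tokens (which, as far as I know, is what transformers' generate does for decoder-only models when given inputs_embeds alone), that slice can come out empty, and argmax over an empty dimension raises the IndexError. Prepending the prompt restores the prompt + completion layout the trainer expects.

```python
import torch


def completions_from_generate(prompt_ids, generated_ids):
    # Assumed trainer behaviour: drop the first prompt_length tokens.
    prompt_length = prompt_ids.size(1)
    return generated_ids[:, prompt_length:]


def eos_index(completion_ids, eos_token_id):
    # Same masking logic as the line in the traceback: index of the first EOS,
    # or the sequence length when a row has no EOS.
    is_eos = completion_ids == eos_token_id
    eos_idx = torch.full((is_eos.size(0),), is_eos.size(1), dtype=torch.long, device=completion_ids.device)
    has_eos = is_eos.any(dim=1)
    # argmax raises IndexError here if the completion dimension is empty.
    eos_idx[has_eos] = is_eos.int().argmax(dim=1)[has_eos]
    return eos_idx


def pad_back_prompt(prompt_ids, new_token_ids):
    # The fix: prepend the prompt so the output looks like input_ids-based generate().
    return torch.cat([prompt_ids, new_token_ids], dim=1)


prompt_ids = torch.tensor([[5, 6, 7], [8, 9, 10]])
new_tokens = torch.tensor([[11, 2], [12, 13]])  # 2 stands in for the EOS id
full = pad_back_prompt(prompt_ids, new_tokens)
print(eos_index(completions_from_generate(prompt_ids, full), eos_token_id=2))  # tensor([1, 2])
```

Without pad_back_prompt, completions_from_generate(prompt_ids, new_tokens) has shape (2, 0) here, and eos_index on it raises the same "Expected reduction dim 1 to have non-zero size" error.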