Inspired by this post, we customize the Hugging Face Transformers Trainer into a general-purpose trainer.

The model is defined as usual.
import torch.nn as nn

class Model(nn.Module):
    def forward(self, inputs):
        ...
        return logits
Customizing the loss function: the loss is either computed inside the model's forward (the Hugging Face convention), or by subclassing Trainer and overriding compute_loss.
import transformers
import torch.nn as nn

class MyTrainer(transformers.Trainer):
    def compute_loss(self, model, inputs, return_outputs=False):
        # Pull the labels out so the model only receives its real inputs.
        labels = inputs.pop('labels')
        logits = model(**inputs)
        # loss_fct = nn.CrossEntropyLoss()
        loss_fct = nn.BCEWithLogitsLoss()
        loss = loss_fct(logits, labels)
        # TODO: tested only with `return_outputs=False`
        return (loss, {'logits': logits}) if return_outputs else loss
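For the first option, computing the loss inside forward in the Hugging Face style, a minimal sketch might look like the following (the linear layer, its sizes, and the class name are made up for illustration; only the output layout matters):

```python
import torch
import torch.nn as nn

class ModelWithLoss(nn.Module):
    """Sketch of the Hugging Face convention: loss computed in forward."""
    def __init__(self, in_features=16, num_labels=4):  # sizes are illustrative
        super().__init__()
        self.classifier = nn.Linear(in_features, num_labels)
        self.loss_fct = nn.BCEWithLogitsLoss()

    def forward(self, inputs, labels=None):
        logits = self.classifier(inputs)
        if labels is not None:
            # Trainer reads the loss from outputs['loss'] (or outputs[0]),
            # so return it alongside the logits when labels are provided.
            loss = self.loss_fct(logits, labels)
            return {'loss': loss, 'logits': logits}
        return {'logits': logits}
```

With this layout the stock Trainer should work unchanged, since it takes the loss straight from the model's outputs; no compute_loss override is needed.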

