Huggingface Transformers Trainer as a general PyTorch trainer

Inspired by this post, these notes customize the Huggingface Transformers Trainer to serve as a general-purpose PyTorch trainer.

Define the model as usual.

import torch
import torch.nn as nn

class Model(nn.Module):
    def forward(self, inputs):
        ...
        return logits
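
For concreteness, a filled-in version of this skeleton could look like the following (a toy binary classifier; the sizes and layers are illustrative, not from the original post):

class Model(nn.Module):
    def __init__(self, in_dim: int = 128, hidden_dim: int = 64):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(in_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, 1),
        )

    def forward(self, inputs):
        # return raw logits; the loss is computed in compute_loss below
        return self.net(inputs).squeeze(-1)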

Define a custom loss function. The loss is either computed inside the model's forward (the Huggingface convention), or you subclass Trainer and override compute_loss.

import transformers

class MyTrainer(transformers.Trainer):
    def compute_loss(self, model, inputs, return_outputs=False):
        labels = inputs.pop('labels')
        logits = model(**inputs)
        # loss_fct = nn.CrossEntropyLoss()
        loss_fct = nn.BCEWithLogitsLoss()
        loss = loss_fct(logits, labels)
        # TODO: tested only with `return_outputs=False`
        return (loss, {'logits': logits}) if return_outputs else loss

Define the metrics.

from typing import Dict

def compute_metrics(eval_pred: transformers.EvalPrediction) -> Dict:
    ...
    return {
        'f1': ...,
    }
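
A filled-in sketch, assuming the binary setup above (raw logits paired with BCEWithLogitsLoss); the 0.5 threshold and the choice of sklearn's f1_score are illustrative:

from typing import Dict

from sklearn.metrics import f1_score

def compute_metrics(eval_pred: transformers.EvalPrediction) -> Dict:
    logits, labels = eval_pred.predictions, eval_pred.label_ids
    preds = (logits > 0).astype(int)  # sigmoid(logit) > 0.5  <=>  logit > 0
    return {'f1': f1_score(labels.astype(int), preds)}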

Next, initialize the model and define a custom torch.utils.data.Dataset and collate_fn (the Trainer assembles them into a DataLoader automatically); a collate_fn sketch follows the dataset skeleton below.

class MyDataset(torch.utils.data.Dataset):
    ...
    def __getitem__(self, idx):
        return {
            'inputs': ...,
            'labels': ...,
        }
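
A minimal collate_fn sketch, assuming each dataset item stores a feature tensor under 'inputs' and a float label tensor under 'labels' (the keys match the skeleton above; everything else is illustrative):

import torch

def collate_fn(batch):
    # the dict keys here become the keyword arguments passed to model(...)
    # and the key that compute_loss / label_names look up
    return {
        'inputs': torch.stack([item['inputs'] for item in batch]),
        'labels': torch.stack([item['labels'] for item in batch]),
    }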

Specify the training arguments, paying attention to label_names. See the source of the Trainer's prediction_step method.

# trainer.py
class Trainer:
    def __init__(self, ...):
        ...
        default_label_names = find_labels(self.model.__class__)
        self.label_names = default_label_names if self.args.label_names is None else self.args.label_names
        
    def prediction_step(self, ...):
        has_labels = False if len(self.label_names) == 0 else all(inputs.get(k) is not None for k in self.label_names)
        ...
        # labels may be popped when computing the loss (label smoothing for instance) so we grab them first.
        if has_labels or loss_without_labels:
            labels = nested_detach(tuple(inputs.get(name) for name in self.label_names))
            if len(labels) == 1:
                labels = labels[0]
        else:
            labels = None
        ...
        return (loss, logits, labels)
            
        
# utils/generic.py
def find_labels(model_class):
    model_name = model_class.__name__
    if model_name.startswith("TF"):
        signature = inspect.signature(model_class.call)
    elif model_name.startswith("Flax"):
        signature = inspect.signature(model_class.__call__)
    else:
        signature = inspect.signature(model_class.forward)
    if "QuestionAnswering" in model_name:
        return [p for p in signature.parameters if "label" in p or p in ("start_positions", "end_positions")]
    else:
        return [p for p in signature.parameters if "label" in p]

Labels are fetched according to label_names, so for a custom model you must specify it explicitly in TrainingArguments. With the custom dataset above, set label_names=['labels']; otherwise the trainer cannot obtain the labels.
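
A sketch of the corresponding TrainingArguments (output_dir and the numeric values are placeholders):

training_args = transformers.TrainingArguments(
    output_dir='./output',
    learning_rate=1e-3,
    num_train_epochs=10,
    per_device_train_batch_size=32,
    # match the key returned by MyDataset, so prediction_step can find the labels
    label_names=['labels'],
)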

After that, use the Trainer as usual.
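
Roughly, putting the pieces together (the dataset arguments are placeholders):

trainer = MyTrainer(
    model=Model(),
    args=training_args,
    train_dataset=MyDataset(...),
    eval_dataset=MyDataset(...),
    data_collator=collate_fn,
    compute_metrics=compute_metrics,
)
trainer.train()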

Other customizations

The official documentation describes which methods can be overridden.

Optimizer

Override the default optimizer to use Adam without a scheduler.

class MyTrainer(transformers.Trainer):
    def create_optimizer(self):
        self.optimizer = torch.optim.Adam(
            self.model.parameters(), lr=self.args.learning_rate,
        )

    def create_scheduler(self, num_training_steps, optimizer=None):
        # overwrites the default scheduler and does nothing instead
        self.lr_scheduler = torch.optim.lr_scheduler.LambdaLR(
            optimizer, lr_lambda=lambda x: 1,
        )
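
For reference, Trainer.train() reaches these two methods through create_optimizer_and_scheduler(), so overriding just create_optimizer and create_scheduler is enough; no other plumbing needs to change.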

MLflow integration

If mlflow is installed, using the Trainer directly will automatically enable MLflowCallback. A single mlflow run records metrics from the start of training to the end. If you want to keep logging with mlflow after training finishes, do the following.

First, look at the source:

# integrations.py
class MLflowCallback(TrainerCallback):
    def __init__(self):
        ...
        import mlflow
        self._ml_flow = mlflow
        self._auto_end_run = False
        self._initialized = False
    
    def setup(self, args, state, model):
        ...
        if self._ml_flow.active_run() is None:
            self._ml_flow.start_run(...)
            self._auto_end_run = True
            ...
        self._initialized = True
        
    def on_train_begin(self, args, state, control, model=None, **kwargs):
        if not self._initialized:
            self.setup(args, state, model)

    def on_train_end(self, args, state, control, **kwargs):
        if self._initialized and self._auto_end_run and self._ml_flow.active_run():
            self._ml_flow.end_run()

So it suffices to wrap the Trainer inside an outer mlflow run.

with mlflow.start_run(run_name=training_args.run_name):
    trainer.train()
    results = trainer.evaluate(dataset)
    # rename the metrics (e.g. add a prefix) and filter which ones to log
    tmp = collate_metrics(results)
    mlflow.log_metrics(tmp)

The default MLflowCallback logs the TrainingArguments; for other parameters such as the model's hyperparameters, you can write a callback and log them in on_train_begin.
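
A minimal sketch of such a callback (the class name and the extra_params dict are made up for illustration):

import mlflow

class ExtraParamsCallback(transformers.TrainerCallback):
    def __init__(self, extra_params: dict):
        self.extra_params = extra_params  # e.g. model hyperparameters

    def on_train_begin(self, args, state, control, **kwargs):
        # log only from the main process in distributed settings
        if state.is_world_process_zero:
            mlflow.log_params(self.extra_params)

trainer = MyTrainer(..., callbacks=[ExtraParamsCallback({'hidden_dim': 64})])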

Miscellaneous

Which is easier to use, the PyTorch Lightning Trainer or the HuggingFace one?

Lightning allows more customization, while HuggingFace's Trainer is, after all, optimized for transformers. On whether to use these heavily encapsulated trainers at all:

  • Their logging is slick (though you may not need that many features)
  • They work well for standardized tasks: the code stays simple, clear, and modular
  • Distributed training is easy to set up
  • Multi-task training is awkward to express with the HuggingFace Trainer