机器学习项目中的配置文件管理
相信接触过机器学习项目,尤其是从零开始搭建一个代码库的同学们,通常会遇到一个问题:如何管理配置文件。在炼丹过程中,我们会对许多参数进行调整,例如学习率、模型结构、优化器参数等等。这些参数的调整不仅会影响到模型的性能,也会影响到代码的可读性和可维护性。如果直接将这些参数硬编码在代码中,那么每次调整参数都需要修改代码并重新运行,这无疑会增加我们的工作量。并且,随着修改参数的次数增多,我们可能会忘记之前的参数设置,从而导致无法复现之前的实验结果。因此,通常我们会将这些参数保存在配置文件中。将配置文件与代码分离,不仅可以提高代码的可读性和可维护性,还可以方便地管理参数,复现实验结果。
常见的配置文件格式
在 Python 项目中,我们通常会使用 JSON
、YAML
、INI
等格式的配置文件。
例如,著名的目标检测库 MaskRCNN-Benchmark 以及 YOLO 都使用了 YAML
格式的配置文件。YAML
是一种人类可读的数据序列化格式,它的语法简洁明了,适合用来编写配置文件。例如,下面是一个 MaskRCNN-Benchmark 中的配置文件 faster-rcnn-R50-C4.yaml 示例:
MODEL:
META_ARCHITECTURE: "GeneralizedRCNN"
WEIGHT: "catalog://ImageNetPretrained/MSRA/R-50"
RPN:
PRE_NMS_TOP_N_TEST: 6000
POST_NMS_TOP_N_TEST: 1000
DATASETS:
TRAIN: ("coco_2014_train", "coco_2014_valminusminival")
TEST: ("coco_2014_minival",)
SOLVER:
BASE_LR: 0.01
WEIGHT_DECAY: 0.0001
STEPS: (120000, 160000)
MAX_ITER: 180000
IMS_PER_BATCH: 8
读取这类配置文件的代码通常会使用 PyYAML
库,例如:
import yaml
cfg_file = 'faster-rcnn-R50-C4.yaml'
with open(cfg_file, 'r') as f:
cfg = yaml.load(f, Loader=yaml.FullLoader)
通过以上代码读取 YAML
格式的配置文件,我们可以得到一个字典 cfg
,其中包含了配置文件中的所有参数。
print(cfg)
# 输出:
# {'MODEL': {'META_ARCHITECTURE': 'GeneralizedRCNN', 'WEIGHT': 'catalog://ImageNetPretrained/MSRA/R-50', 'RPN': {'PRE_NMS_TOP_N_TEST': 6000, 'POST_NMS_TOP_N_TEST': 1000}}, 'DATASETS': {'TRAIN': '("coco_2014_train", "coco_2014_valminusminival")', 'TEST': '("coco_2014_minival",)'}, 'SOLVER': {'BASE_LR': 0.01, 'WEIGHT_DECAY': 0.0001, 'STEPS': '(120000, 160000)', 'MAX_ITER': 180000, 'IMS_PER_BATCH': 8}}
但是,使用这样的字典结构来管理配置文件,存在一些问题。例如 YAML
通常习惯采用大写字母来表示键,这样以来在 python 代码中会出现许多诸如 cfg['MODEL']['META_ARCHITECTURE']
这样的代码,不仅不够直观,而且容易出错。因此,我们通常会在代码中将配置文件中参数转换为类似 cfg.MODEL.META_ARCHITECTURE
的形式,以提高代码的可读性。由于 yaml
包并不支持这种形式的读取,因此我们通常会使用 yacs 等额外的包来实现这一功能。
但是,当我们涉及到更复杂的配置文件操作时,例如基础配置文件、模型配置文件、数据集配置文件互相引用或合并时,以及配置文件中定义变量、函数等时,这种简单的配置文件格式就显得力不从心了。因此,我们需要更强大的配置文件管理工具。
使用 PJTOOLS 一站式管理配置文件
笔者在使用 OpenMMLab 系列的开源代码库,例如 MMDetection 时,第一次接触到直接使用 python
文件作为配置文件的方式。
例如,一个简单的 detr-r101 的配置文件可以定义如下:
_base_ = './detr_r50_8xb2-500e_coco.py'
model = dict(
backbone=dict(
depth=101,
init_cfg=dict(type='Pretrained',
checkpoint='torchvision://resnet101')))
这带来了几个好处,例如:1. 直接使用 python 语法,对 python 项目更加友好; 2. 可以直接引用其它配置文件,方便复用; 3. 可以直接使用 python 语法定义变量、函数等,使得我们可以在配置文件中实现更复杂的逻辑。
但是,OpenMMLab 系列代码库的配置文件功能是集成在 MMEngine 中的。如果专门为了配置文件管理引入 MMEngine,可能会略显臃肿。
为此,我实现了一个简洁轻量的包 project-tools,简称 pjtools
。目的是为了提供一些机器学习项目中可能会使用到的工具,例如配置文件管理、日志管理等。其中,配置文件管理是 pjtools
中的一个重要功能。
安装起来非常简单:
pip install pjtools
在 pjtools
中,我实现了 AutoConfigurator
类,其支持常用的 YAML
, JSON
, 和 Python
格式的配置文件,并采用了统一的接口进行封装。例如,我们可以使用 AutoConfigurator
来读取不同格式的配置文件:
from pjtools.configurator.configurator import AutoConfigurator
config_file = 'config.yaml'
# config_file = 'config.json'
# config_file = 'config.py'
auto_config = AutoConfigurator.fromfile(config_file)
并且,对于 python
格式的配置文件,我们可以通过在配置文件中指定 _base_
来引用其它配置文件:
# default.py
learning_rate = 0.01
momentum = 0.9
optimizer = 'Adam'
use_cuda = True
# model.py
_base_ = ['tests/data/default.py']
learning_rate = 0.001
通过引用 _base_
,我们可以将 model.py
中的配置文件与 default.py
中的配置文件合并,从而实现配置文件的复用。如果存在相同的键,后者会覆盖前者。
通过这种形式,我们甚至可以实现更加复杂的配置文件操作,例如定义类、变量、函数等。例如,以下配置文件来自我近期一篇论文 ModaVerse: Efficiently Transforming Modalities with LLMs 中的 base.py
配置文件:
from typing import List
import torch
from peft import LoraConfig, TaskType
from transformers import StoppingCriteria, StoppingCriteriaList
class StoppingCriteriaSub(StoppingCriteria):
def __init__(self, stops: List = None, encounters: int = 1):
super().__init__()
self.stops = stops
self.ENCOUNTERS = encounters
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor):
stop_count = 0
for stop in self.stops:
_stop = torch.tensor(stop).to(input_ids[0].device)
indices = torch.where(_stop[0] == input_ids)
for i in indices:
if len(i) > 0:
if torch.all(input_ids[0][i:i + len(_stop)] == _stop):
stop_count += 1
if stop_count >= self.ENCOUNTERS:
return True
return False
prompt_configs = dict(path='assets/prompts/prompt_template.txt',
media_placeholder='{media}',
instruction_placeholder='{instruction}')
model_configs = dict(
name='ModaVerse-7b',
imagebind=dict(hidden_size=1024),
foundation_llm=dict(type='vicuna-7b', checkpoint='.checkpoints/7b_v0'),
modaverse=dict(
max_length=512,
modality_begin_token='<Media>',
modality_end_token='</Media>',
modality_flags=['[TEXT]', '[IMAGE]', '[AUDIO]', '[VIDEO]'],
target_padding=-100,
top_p=0.01,
temperature=1,
max_new_tokens=246,
do_sample=True,
use_cache=True,
stopping_token=835,
stopping_criteria=StoppingCriteriaList(
[StoppingCriteriaSub(stops=[[835]], encounters=1)], ),
generator=dict(
image_diffuser=dict(
type='stable_diffusion',
# preload=False,
cfgs=dict(model='runwayml/stable-diffusion-v1-5')),
video_diffuser=dict(
type='damo_vilab',
# preload=False,
cfgs=dict(model='damo-vilab/text-to-video-ms-1.7b')),
audio_diffuser=dict(type='audio_ldm',
cfgs=dict(model='cvssp/audioldm-l-full')),
),
))
training_configs = dict(
lora_config=LoraConfig(
task_type=TaskType.CAUSAL_LM,
inference_mode=False,
r=32,
lora_alpha=32,
lora_dropout=0.1,
target_modules=['q_proj', 'k_proj', 'v_proj', 'o_proj']),
deepspeed_cfg=dict(path='configs/dscfg.json', backend='nccl'),
saving_root='./experiments',
epochs=1,
warmup_rate=0.1,
force_training_layers=['embed_tokens.weight', 'lm_head.weight'],
report_backend=dict(type='wandb', iterval=10),
print_prediction=dict(turn_on=True, interval=1000),
checkpointer=dict(type='iteration', interval=5000))
dataset_configs = dict(train=dict(instruction_path='dataset/instructions.json',
media_root='dataset/'))
可以看到,通过使用 python
格式的配置文件,我们可以定义类、变量、函数等,实现更加复杂的配置文件操作。而 pjtools
的 AutoConfigurator
类可以帮助我们统一管理这些配置文件,提高代码的可读性和可维护性。
Leave a comment