龙空技术网

How many random seeds need to be set during training?

技术债 690


To make training results reproducible and comparable, experiments should be run in as identical an environment as possible (same CPU, GPU, memory, data, etc.). The data seen in each epoch must be identical across runs, which means every run must use the same random seeds.

How many kinds of random seeds are there?

Python's built-in seed: random.seed()
NumPy's seed: numpy.random.seed()
The training framework's seed:
    PyTorch: torch.manual_seed() or torch.cuda.manual_seed_all()
    TensorFlow: tensorflow.set_random_seed()
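As a minimal sketch of setting all of the above in one place (the helper name set_seed is ours, not a library API; the PyTorch calls are guarded so the sketch also runs without torch installed):

```python
import os
import random

import numpy as np


def set_seed(seed: int = 42) -> None:
    """Seed every RNG a training run might touch (sketch)."""
    random.seed(seed)      # Python's built-in RNG
    np.random.seed(seed)   # NumPy's global RNG
    # hash seed for any child Python processes (has no effect on the current one)
    os.environ["PYTHONHASHSEED"] = str(seed)
    try:
        import torch       # only if PyTorch is installed
        torch.manual_seed(seed)           # CPU RNG
        torch.cuda.manual_seed_all(seed)  # every visible GPU
    except ImportError:
        pass


set_seed(123)
a = (random.random(), np.random.rand())
set_seed(123)
b = (random.random(), np.random.rand())
print(a == b)  # True: re-seeding reproduces the same draws
```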

... (other kinds omitted)

Of course, if you don't use NumPy you can skip its seed, but it is best to set them all anyway; you never know which third-party library relies on it under the hood.

Note: with multiprocessing, a forked child process may inherit the parent's random seed state. If you don't want every process to end up with the same seed, each process must set its own seed separately.
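One simple way to give each process its own reproducible seed is to derive it from a base seed plus a per-process id. NumPy's SeedSequence is well suited for this (it is also the mechanism Lightning's worker init function uses). A sketch, with the helper name per_process_seed being our own:

```python
import random

import numpy as np


def per_process_seed(base_seed: int, process_id: int) -> int:
    """Derive a distinct, reproducible seed for one worker/process.

    SeedSequence mixes the base seed and the process id into a
    well-distributed stream, so workers don't share RNG state.
    """
    ss = np.random.SeedSequence([base_seed, process_id])
    return int(ss.generate_state(1)[0])


# Each worker would seed its own RNGs with its derived seed:
seeds = [per_process_seed(42, pid) for pid in range(4)]
for s in seeds:
    random.seed(s)  # in practice, each process does this once at startup
print(seeds)  # same list on every run, but a different seed per process
```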

If you happen to be using PyTorch, the pytorch-lightning library can do all of this in one line of code, at the cost of one extra dependency; weigh the trade-off yourself.

import pytorch_lightning as pl

seed = 42
pl.seed_everything(seed)

You can also refer to its implementation:

def seed_everything(seed: Optional[int] = None, workers: bool = False) -> int:
    """Function that sets seed for pseudo-random number generators in: pytorch, numpy, python.random. In addition,
    sets the following environment variables:

    - `PL_GLOBAL_SEED`: will be passed to spawned subprocesses (e.g. ddp_spawn backend).
    - `PL_SEED_WORKERS`: (optional) is set to 1 if ``workers=True``.

    Args:
        seed: the integer value seed for global random state in Lightning.
            If `None`, will read seed from `PL_GLOBAL_SEED` env variable
            or select it randomly.
        workers: if set to ``True``, will properly configure all dataloaders passed to the
            Trainer with a ``worker_init_fn``. If the user already provides such a function
            for their dataloaders, setting this argument will have no influence. See also:
            :func:`~pytorch_lightning.utilities.seed.pl_worker_init_function`.
    """
    max_seed_value = np.iinfo(np.uint32).max
    min_seed_value = np.iinfo(np.uint32).min

    if seed is None:
        env_seed = os.environ.get("PL_GLOBAL_SEED")
        if env_seed is None:
            seed = _select_seed_randomly(min_seed_value, max_seed_value)
            rank_zero_warn(f"No seed found, seed set to {seed}")
        else:
            try:
                seed = int(env_seed)
            except ValueError:
                seed = _select_seed_randomly(min_seed_value, max_seed_value)
                rank_zero_warn(f"Invalid seed found: {repr(env_seed)}, seed set to {seed}")
    elif not isinstance(seed, int):
        seed = int(seed)

    if not (min_seed_value <= seed <= max_seed_value):
        rank_zero_warn(f"{seed} is not in bounds, numpy accepts from {min_seed_value} to {max_seed_value}")
        seed = _select_seed_randomly(min_seed_value, max_seed_value)

    # using `log.info` instead of `rank_zero_info`,
    # so users can verify the seed is properly set in distributed training.
    log.info(f"Global seed set to {seed}")
    os.environ["PL_GLOBAL_SEED"] = str(seed)
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    os.environ["PL_SEED_WORKERS"] = f"{int(workers)}"

    return seed

Setting each worker's seed in worker_init_fn:

DataLoader(
    train_dataset,
    batch_size=batch_size,
    num_workers=num_workers,
    worker_init_fn=pl_worker_init_function,
    ...
)

def pl_worker_init_function(worker_id: int, rank: Optional[int] = None) -> None:  # pragma: no cover
    """The worker_init_fn that Lightning automatically adds to your dataloader if you previously set the seed
    with ``seed_everything(seed, workers=True)``.

    See also the PyTorch documentation on randomness in DataLoaders.
    """
    global_rank = rank if rank is not None else rank_zero_only.rank
    process_seed = torch.initial_seed()
    # back out the base seed so we can use all the bits
    base_seed = process_seed - worker_id
    log.debug(
        f"Initializing random number generators of process {global_rank} worker {worker_id} with base seed {base_seed}"
    )
    ss = np.random.SeedSequence([base_seed, worker_id, global_rank])
    # use 128 bits (4 x 32-bit words)
    np.random.seed(ss.generate_state(4))
    # Spawn distinct SeedSequences for the PyTorch PRNG and the stdlib random module
    torch_ss, stdlib_ss = ss.spawn(2)
    torch.manual_seed(torch_ss.generate_state(1, dtype=np.uint64)[0])
    # use 128 bits expressed as an integer
    stdlib_seed = (stdlib_ss.generate_state(2, dtype=np.uint64).astype(object) * [1 << 64, 1]).sum()
    random.seed(stdlib_seed)

This pins down the seed of every process, ensuring the results stay consistent from run to run.
