# Copyright (c) 2023 Amphion.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import os, sys
import os.path as osp
import numpy as np
import torch
from torch import nn
from torch.optim import Optimizer
from functools import reduce
from torch.optim import AdamW

class MultiOptimizer:
    """Wraps several named optimizers and schedulers so they can be stepped,
    zeroed, and (de)serialized together."""

    def __init__(self, optimizers=None, schedulers=None):
        # Avoid mutable default arguments.
        self.optimizers = optimizers if optimizers is not None else {}
        self.schedulers = schedulers if schedulers is not None else {}
        self.keys = list(self.optimizers.keys())
        # Flatten the param_groups of all wrapped optimizers into a single list.
        self.param_groups = reduce(
            lambda x, y: x + y, [v.param_groups for v in self.optimizers.values()]
        )

    def state_dict(self):
        state_dicts = [(key, self.optimizers[key].state_dict()) for key in self.keys]
        return state_dicts

    def scheduler_state_dict(self):
        state_dicts = [(key, self.schedulers[key].state_dict()) for key in self.keys]
        return state_dicts

    def load_state_dict(self, state_dict):
        for key, val in state_dict:
            try:
                self.optimizers[key].load_state_dict(val)
            except (KeyError, ValueError, RuntimeError):
                print("Failed to load optimizer state for %s" % key)

    def load_scheduler_state_dict(self, state_dict):
        for key, val in state_dict:
            try:
                self.schedulers[key].load_state_dict(val)
            except (KeyError, ValueError, RuntimeError):
                print("Failed to load scheduler state for %s" % key)

    def step(self, key=None, scaler=None):
        # Step a single optimizer if a key is given, otherwise step all of them.
        keys = [key] if key is not None else self.keys
        _ = [self._step(k, scaler) for k in keys]

    def _step(self, key, scaler=None):
        if scaler is not None:
            # AMP path: the GradScaler unscales gradients and skips steps on inf/NaN.
            scaler.step(self.optimizers[key])
            scaler.update()
        else:
            self.optimizers[key].step()

    def zero_grad(self, key=None):
        if key is not None:
            self.optimizers[key].zero_grad()
        else:
            _ = [self.optimizers[k].zero_grad() for k in self.keys]

    def scheduler(self, *args, key=None):
        # Step a single scheduler if a key is given, otherwise step all of them.
        if key is not None:
            self.schedulers[key].step(*args)
        else:
            _ = [self.schedulers[k].step(*args) for k in self.keys]

def define_scheduler(optimizer, params):
    # Exponential learning-rate decay; params["gamma"] is the per-step decay factor.
    scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=params["gamma"])

    return scheduler

def build_optimizer(model_dict, scheduler_params_dict, lr, type="AdamW"):
    # Builds one AdamW optimizer per model; scheduler_params_dict is currently
    # unused and the schedulers below use a fixed decay factor.
    optim = {}
    for key, model in model_dict.items():
        model_parameters = model.parameters()
        # Parameter names are collected for inspection only; they are not used below.
        parameters_names = []
        parameters_names.append(
            [name_param_pair[0] for name_param_pair in model.named_parameters()]
        )
        if type == "AdamW":
            optim[key] = AdamW(
                model_parameters,
                lr=lr,
                betas=(0.9, 0.98),
                eps=1e-9,
                weight_decay=0.1,
            )
        else:
            raise ValueError("Unknown optimizer type: %s" % type)

    # One ExponentialLR scheduler per optimizer.
    schedulers = dict(
        [
            (key, torch.optim.lr_scheduler.ExponentialLR(opt, gamma=0.999996))
            for key, opt in optim.items()
        ]
    )

    multi_optim = MultiOptimizer(optim, schedulers)
    return multi_optim
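

# A minimal usage sketch, not part of the original module: it assumes two small
# hypothetical nn.Linear models and shows how build_optimizer wires them into a
# MultiOptimizer that is then stepped, zeroed, and decayed for all keys at once.
if __name__ == "__main__":
    models = {
        "generator": nn.Linear(8, 8),
        "discriminator": nn.Linear(8, 1),
    }
    multi_optim = build_optimizer(models, scheduler_params_dict={}, lr=1e-4)

    # One forward/backward pass with dummy data, then update both models.
    x = torch.randn(4, 8)
    loss = models["generator"](x).pow(2).mean() + models["discriminator"](x).pow(2).mean()
    loss.backward()
    multi_optim.step()       # steps every wrapped optimizer
    multi_optim.zero_grad()  # clears gradients for every wrapped optimizer
    multi_optim.scheduler()  # decays every learning rate by its gamma

    # Optimizer state can be saved and restored per key.
    saved = multi_optim.state_dict()
    multi_optim.load_state_dict(saved)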