mirror of
https://github.com/index-tts/index-tts.git
synced 2025-11-28 18:30:25 +08:00
* indextts2 * update lfs for audio files --------- Co-authored-by: wangyining02 <wangyining02@bilibili.com>
35 lines
859 B
Python
35 lines
859 B
Python
# Copyright (c) 2023 Amphion.
|
|
#
|
|
# This source code is licensed under the MIT license found in the
|
|
# LICENSE file in the root directory of this source tree.
|
|
|
|
from torch.autograd import Function
|
|
import torch
|
|
from torch import nn
|
|
|
|
|
|
class _GradientReversalFunction(Function):
    """Autograd op that is the identity forward and negates/scales the gradient backward.

    Forward returns ``x`` unchanged; backward returns ``-alpha * grad_output``
    as the gradient for ``x`` and no gradient for ``alpha``.

    This class is private and named differently from the ``GradientReversal``
    ``nn.Module`` in this file, so the two classes do not shadow each other.
    Call it through ``revgrad``.
    """

    @staticmethod
    def forward(ctx, x, alpha):
        """Return ``x`` unchanged and remember ``alpha`` for the backward pass.

        Args:
            x: Input tensor; passed through as-is.
            alpha: Gradient scale, either a tensor or a plain Python number.
        """
        # Only alpha is needed in backward; x is deliberately NOT saved so the
        # activation is not kept alive until backward runs.
        if torch.is_tensor(alpha):
            # Tensors go through save_for_backward (autograd's supported path).
            ctx.save_for_backward(alpha)
            ctx.alpha = None
        else:
            # Non-tensors cannot be passed to save_for_backward; store directly.
            ctx.alpha = alpha
        return x

    @staticmethod
    def backward(ctx, grad_output):
        """Return ``(-alpha * grad_output, None)``; ``alpha`` itself gets no gradient."""
        grad_input = None
        if ctx.needs_input_grad[0]:
            alpha = ctx.saved_tensors[0] if ctx.alpha is None else ctx.alpha
            grad_input = -alpha * grad_output
        return grad_input, None


revgrad = _GradientReversalFunction.apply
|
|
|
|
|
|
class GradientReversal(nn.Module):
    """Layer wrapper around ``revgrad``.

    The forward pass returns ``x`` unchanged. In the backward pass, the gradient
    flowing into ``x`` is multiplied by ``-alpha``.

    Args:
        alpha: Scale of the reversed gradient; a Python number or a tensor.
    """

    def __init__(self, alpha):
        super().__init__()
        if isinstance(alpha, torch.Tensor):
            # torch.tensor(<tensor>) copies but emits a UserWarning; copy explicitly.
            alpha = alpha.detach().clone()
        else:
            alpha = torch.tensor(alpha)
        # A buffer (not a plain attribute) so .to()/.cuda()/.half() move and cast
        # it with the module. persistent=False keeps it out of state_dict, as the
        # previous plain attribute was, so existing checkpoints still load.
        self.register_buffer("alpha", alpha, persistent=False)

    def forward(self, x):
        return revgrad(x, self.alpha)
|