- 前向传播
#include <vector>
/// Forward pass of the LLTM cell.
///
/// Concatenates `old_h` and `input` along dim 1, applies the affine map
/// `bias + X * weights^T`, splits the result into three chunks along dim 1
/// and activates them as input gate (sigmoid), output gate (sigmoid) and
/// candidate cell (ELU with alpha = 1).
///
/// @return {new_h, new_cell, input_gate, output_gate, candidate_cell, X,
///          gate_weights}, in that order. Everything after new_cell is an
///          intermediate value of this computation.
std::vector<at::Tensor> lltm_forward(
    torch::Tensor input,
    torch::Tensor weights,
    torch::Tensor bias,
    torch::Tensor old_h,
    torch::Tensor old_cell) {
  // Previous hidden state and current input, side by side.
  const auto joined = torch::cat({old_h, input}, /*dim=*/1);

  // Pre-activation values of all three gates in a single matrix product.
  const auto pre_activations =
      torch::addmm(bias, joined, weights.transpose(0, 1));
  const auto parts = pre_activations.chunk(3, /*dim=*/1);

  const auto input_gate = parts[0].sigmoid();
  const auto output_gate = parts[1].sigmoid();
  const auto candidate_cell = torch::elu(parts[2], /*alpha=*/1.0);

  // Cell update followed by the new hidden state.
  const auto new_cell = old_cell + candidate_cell * input_gate;
  const auto new_h = new_cell.tanh() * output_gate;

  return {new_h,
          new_cell,
          input_gate,
          output_gate,
          candidate_cell,
          joined,
          pre_activations};
}
- 反向传播
// Derivative of tanh evaluated at z: tanh'(z) = 1 - tanh(z)^2.
torch::Tensor d_tanh(torch::Tensor z) {
  const auto tanh_z = z.tanh();
  const auto tanh_z_squared = tanh_z.pow(2);
  return 1 - tanh_z_squared;
}
// Derivative of ELU evaluated at z, built from two masked terms:
//   elu'(z) = [z > 0] + [alpha * (exp(z) - 1) < 0] * alpha * exp(z)
// where [.] is the comparison mask cast to the dtype of z.
torch::Tensor d_elu(torch::Tensor z, torch::Scalar alpha = 1.0) {
  const auto exp_z = z.exp();
  // 1 where z is strictly positive, 0 elsewhere.
  const auto positive_term = (z > 0).type_as(z);
  // 1 where the exponential branch alpha * (exp(z) - 1) is negative.
  const auto exp_branch_mask = ((alpha * (exp_z - 1)) < 0).type_as(z);
  return positive_term + exp_branch_mask * (alpha * exp_z);
}
/// Backward pass of the LLTM cell.
///
/// Given the gradients of the loss with respect to the two forward outputs
/// (`grad_h` for new_h, `grad_cell` for new_cell) and the intermediates
/// returned by lltm_forward, applies the chain rule step by step in the
/// reverse order of the forward computation.
///
/// NOTE(review): assumes `input_gate` and `output_gate` are the sigmoid
/// activations returned by lltm_forward for the same `gate_weights`; their
/// derivative is computed directly from them.
///
/// @return {d_old_h, d_input, d_weights, d_bias, d_old_cell}, in that order.
std::vector<torch::Tensor> lltm_backward(
    torch::Tensor grad_h,
    torch::Tensor grad_cell,
    torch::Tensor new_cell,
    torch::Tensor input_gate,
    torch::Tensor output_gate,
    torch::Tensor candidate_cell,
    torch::Tensor X,
    torch::Tensor gate_weights,
    torch::Tensor weights) {
  // Forward: new_h = tanh(new_cell) * output_gate  (product rule).
  auto d_output_gate = torch::tanh(new_cell) * grad_h;
  auto d_tanh_new_cell = output_gate * grad_h;
  // new_cell receives gradient through new_h and, additionally, the
  // gradient grad_cell that arrives for new_cell itself.
  auto d_new_cell = d_tanh(new_cell) * d_tanh_new_cell + grad_cell;

  // Forward: new_cell = old_cell + candidate_cell * input_gate.
  auto d_old_cell = d_new_cell;
  auto d_candidate_cell = input_gate * d_new_cell;
  auto d_input_gate = candidate_cell * d_new_cell;

  // Back through the gate activations. sigmoid'(z) = (1 - s) * s with
  // s = sigmoid(z); s is already available as input_gate / output_gate, so
  // the sigmoid is not recomputed from gate_weights.
  d_input_gate *= (1 - input_gate) * input_gate;
  d_output_gate *= (1 - output_gate) * output_gate;
  // The candidate cell used ELU, whose derivative needs the pre-activation
  // values, i.e. the third chunk of gate_weights.
  auto gates = gate_weights.chunk(3, /*dim=*/1);
  d_candidate_cell *= d_elu(gates[2]);

  auto d_gates =
      torch::cat({d_input_gate, d_output_gate, d_candidate_cell}, /*dim=*/1);

  // d_gates is the gradient with respect to the pre-activation gate values
  // (already propagated back through sigmoid / ELU); multiplying it by the
  // layer input X yields the gradient of the full weight matrix.
  auto d_weights = d_gates.t().mm(X);
  auto d_bias = d_gates.sum(/*dim=*/0, /*keepdim=*/true);

  // d_X is the gradient with respect to X = cat({old_h, input}, dim 1).
  auto d_X = d_gates.mm(weights);

  // Split d_X back into its two parts: the first state_size columns belong
  // to old_h, the remaining columns to input.
  const auto state_size = grad_h.size(1);
  auto d_old_h = d_X.slice(/*dim=*/1, 0, state_size);
  auto d_input = d_X.slice(/*dim=*/1, state_size);

  return {d_old_h, d_input, d_weights, d_bias, d_old_cell};
}
关于如何计算得到d_output_gate(output_gate梯度),d_new_cell(new_cell梯度)
链式求导法则,对于d_output_gate,假设L是loss
前向传播公式为:
auto new_h = torch::tanh(new_cell) * output_gate;
反向传播则为:
d(L)/d(output_gate) = d(L)/d(new_h) * d(new_h)/d(output_gate) = grad_h * torch::tanh(new_cell)
同理,对于d_new_cell,
前向传播为
auto new_cell = old_cell + candidate_cell * input_gate;
auto new_h = torch::tanh(new_cell) * output_gate;
反向传播则为:
d(L)/d(new_h)*d(new_h)/d(new_cell) = grad_h *(d_tanh(new_cell)*output_gate)
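这里为什么还要加上 grad_cell?因为 new_cell 不只用来计算 new_h,它本身也是前向传播的返回值之一(通常会作为下一个时间步的 old_cell),所以 loss 通过两条路径依赖 new_cell:一条经过 new_h(即上式),另一条直接经过返回的 new_cell,其梯度就是 grad_cell。按多元函数的链式法则,两条路径的梯度要相加:d_new_cell = d_tanh(new_cell) * output_gate * grad_h + grad_cell。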
个人理解是new_cell 用到了old_cell, 也可以看作,把new_cell带入torch::tanh(new_cell)中,