Baremetal-NN
Baremetal-NN API documentation
void nn_scaled_dot_product_attention_f32(
    Tensor4D_F32 *y,
    const Tensor4D_F32 *query,
    const Tensor4D_F32 *key,
    const Tensor4D_F32 *value
)
Computes scaled dot product attention on query, key and value tensors.
Computes: y = softmax((query @ key.transpose(-2, -1)) / sqrt(E)) @ value
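Shape legend:
- N: batch size
- H: number of attention heads
- L: target (query) sequence length
- S: source (key/value) sequence length
- E: embedding dimension of the query and key
- Ev: embedding dimension of the value
Parameters: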
| y | The output tensor, of shape (N, H, L, Ev). |
| query | The query tensor, of shape (N, H, L, E). |
| key | The key tensor, of shape (N, H, S, E). |
| value | The value tensor, of shape (N, H, S, Ev). |