Baremetal-NN
Baremetal-NN API documentation
void nn_scaled_dot_product_attention_f32(Tensor4D_F32 *y, const Tensor4D_F32 *query, const Tensor4D_F32 *key, const Tensor4D_F32 *value)
Computes scaled dot-product attention over the query, key, and value tensors and writes the result to y.
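Shape legend:
N: batch size
H: number of attention heads
L: target (query) sequence length
S: source (key/value) sequence length
E: embedding dimension of the query and key
Ev: embedding dimension of the value
Parameters: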
y | The output tensor, of shape (N, H, L, Ev). |
query | The query tensor, of shape (N, H, L, E). |
key | The key tensor, of shape (N, H, S, E). |
value | The value tensor, of shape (N, H, S, Ev). |