Baremetal-NN
Baremetal-NN API documentation

nn_scaled_dot_product_attention_f32()

void nn_scaled_dot_product_attention_f32(Tensor4D_F32 y,
                                         const Tensor4D_F32 query,
                                         const Tensor4D_F32 key,
                                         const Tensor4D_F32 value)

Computes scaled dot-product attention over the query, key, and value tensors and writes the result to y.

Shape legend:

  • N: batch size
  • S: source sequence length
  • L: target sequence length
  • E: embedding dimension of the query and key
  • Ev: embedding dimension of the value
  • H: number of attention heads

Parameters:

  • y: The output tensor, of shape (N, H, L, Ev).
  • query: The query tensor, of shape (N, H, L, E).
  • key: The key tensor, of shape (N, H, S, E).
  • value: The value tensor, of shape (N, H, S, Ev).