Baremetal-NN
Baremetal-NN API documentation

◆ nn_scaled_dot_product_attention_f32()

void nn_scaled_dot_product_attention_f32 ( Tensor4D_F32 y,
const Tensor4D_F32 query,
const Tensor4D_F32 key,
const Tensor4D_F32 value 
)

Computes scaled dot product attention on query, key and value tensors.


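Computes: y = softmax((query @ key.transpose(-2, -1)) / sqrt(E)) @ value, where @ is a matrix multiplication applied independently to each (batch, head) pair and softmax is taken over the last dimension (S).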

Shape legend:

  • N: batch size
  • H: number of attention heads
  • L: target sequence length (query length)
  • S: source sequence length (key/value length)
  • E: embedding dimension of the query and key
  • Ev: embedding dimension of the value
Parameters
  • y: The output tensor, of shape (N, H, L, Ev).
  • query: The query tensor, of shape (N, H, L, E).
  • key: The key tensor, of shape (N, H, S, E).
  • value: The value tensor, of shape (N, H, S, Ev).
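The parameter shapes imply a few consistency constraints: all four tensors share N and H, key shares E with query, value shares S with key, and y takes L from query and Ev from value. The sketch below checks them using a hypothetical `Tensor4DSketch` struct (a shape array plus a data pointer) as a stand-in; the real Tensor4D_F32 layout may differ, and the `offset4d` helper additionally assumes contiguous row-major storage.

```c
#include <stdbool.h>
#include <stddef.h>

/* Hypothetical stand-in for a 4-D tensor descriptor; the actual
 * Tensor4D_F32 definition may differ. shape[] = {N, H, rows, cols}. */
typedef struct {
  size_t shape[4];
  float *data;
} Tensor4DSketch;

/* Returns true when y (N,H,L,Ev), query (N,H,L,E), key (N,H,S,E) and
 * value (N,H,S,Ev) have mutually consistent shapes. */
static bool sdpa_shapes_ok(const Tensor4DSketch *y, const Tensor4DSketch *query,
                           const Tensor4DSketch *key, const Tensor4DSketch *value) {
  const size_t N = query->shape[0], H = query->shape[1];
  const size_t L = query->shape[2], E = query->shape[3];
  const size_t S = key->shape[2], Ev = value->shape[3];

  /* Batch size and head count must agree across all four tensors */
  const Tensor4DSketch *all[4] = {y, query, key, value};
  for (int i = 0; i < 4; i++) {
    if (all[i]->shape[0] != N || all[i]->shape[1] != H) {
      return false;
    }
  }
  return key->shape[3] == E      /* key embedding matches query's E */
      && value->shape[2] == S    /* one value row per key row (S) */
      && y->shape[2] == L        /* one output row per query position */
      && y->shape[3] == Ev;      /* output width follows value's Ev */
}

/* Row-major offset of element [n][h][i][j] in a contiguous 4-D tensor. */
static size_t offset4d(const Tensor4DSketch *t, size_t n, size_t h,
                       size_t i, size_t j) {
  return ((n * t->shape[1] + h) * t->shape[2] + i) * t->shape[3] + j;
}
```

For instance, query (2, 4, 3, 8), key (2, 4, 5, 8) and value (2, 4, 5, 6) are consistent with an output of shape (2, 4, 3, 6), while an output of shape (2, 4, 3, 8) is rejected because its last dimension does not match Ev.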