models.diffusion.diffusion_policy
Source: models/diffusion/diffusion_policy.py
ActionRefinementBlock
Bases: Module
Refines the action sequence with temporal awareness.
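The docstring does not say how temporal awareness is achieved. The following is a minimal sketch, assuming a residual block that mixes neighbouring timesteps with 1D convolutions along the time axis. The class name `TemporalRefinementSketch` and its parameters (`hidden_dim`, `kernel_size`) are hypothetical and are not the source implementation of `ActionRefinementBlock`.

```python
import torch
import torch.nn as nn


class TemporalRefinementSketch(nn.Module):
    """Hypothetical stand-in for ActionRefinementBlock (not the source code).

    Mixes information across neighbouring timesteps of an action sequence with
    1D convolutions over the time axis, then adds the result back residually.
    """

    def __init__(self, action_dim: int, hidden_dim: int = 64, kernel_size: int = 3):
        super().__init__()
        padding = kernel_size // 2  # keeps sequence length unchanged for odd kernel sizes
        self.net = nn.Sequential(
            nn.Conv1d(action_dim, hidden_dim, kernel_size, padding=padding),
            nn.ReLU(),
            nn.Conv1d(hidden_dim, action_dim, kernel_size, padding=padding),
        )

    def forward(self, actions: torch.Tensor) -> torch.Tensor:
        # actions: [batch, num_actions, action_dim]; Conv1d expects [batch, channels, length]
        x = self.net(actions.transpose(1, 2))
        return actions + x.transpose(1, 2)  # residual: refine rather than replace


block = TemporalRefinementSketch(action_dim=7)
refined = block(torch.randn(4, 16, 7))
print(refined.shape)  # torch.Size([4, 16, 7])
```

Treating actions as channels and timesteps as the convolution length lets each refined action depend on its neighbours in the sequence, while the residual path leaves the input unchanged if the convolutions output zero.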
DiffusionPolicy
Bases: Module
Diffusion Policy model for control tasks. It takes past robot states (as vectors) and predicts action sequences. The architecture is simplified to use fewer, well-supported operators.
denoise_step(past_states, actions, timestep, image=None)
Performs a single denoising step, simplified to use fewer operations.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `past_states` |  | Past robot states [batch, num_past_states, state_dim], or flattened | required |
| `actions` |  | Noisy action sequence [batch, num_actions, action_dim] | required |
| `timestep` |  | Diffusion timestep [batch] | required |
| `image` |  | RGB image [batch, 3, 224, 224] (optional) | `None` |
Returns:
| Type | Description |
|---|---|
|  | Predicted noise [batch, num_actions, action_dim] |
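Because `denoise_step` returns predicted noise with the same shape as `actions`, it fits the standard epsilon-prediction DDPM objective. The sketch below assumes that objective and a linear beta schedule with 100 steps; neither is confirmed by this page. `ToyNoisePredictor` and `ddpm_training_loss` are hypothetical names, and the toy model only mimics the `(past_states, actions, timestep, image=None)` call signature. With the real model you would pass a `DiffusionPolicy` instance (or call its `denoise_step`) instead.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class ToyNoisePredictor(nn.Module):
    """Hypothetical stand-in with the same call signature as DiffusionPolicy.denoise_step."""

    def __init__(self, state_dim, num_past_states, action_dim, num_actions):
        super().__init__()
        self.num_actions, self.action_dim = num_actions, action_dim
        in_dim = num_past_states * state_dim + num_actions * action_dim + 1
        self.net = nn.Sequential(
            nn.Linear(in_dim, 128), nn.ReLU(), nn.Linear(128, num_actions * action_dim)
        )

    def forward(self, past_states, actions, timestep, image=None):
        b = actions.shape[0]
        # reshape(b, -1) accepts past_states either as [b, n, d] or already flattened
        x = torch.cat(
            [past_states.reshape(b, -1), actions.reshape(b, -1), timestep.float().unsqueeze(-1)],
            dim=-1,
        )
        return self.net(x).view(b, self.num_actions, self.action_dim)


def ddpm_training_loss(model, past_states, clean_actions, alphas_cumprod):
    """Epsilon-prediction loss (assumed objective): noise the actions, predict the noise."""
    b = clean_actions.shape[0]
    t = torch.randint(0, alphas_cumprod.shape[0], (b,))      # one random timestep per sample
    noise = torch.randn_like(clean_actions)
    a_bar = alphas_cumprod[t].view(b, 1, 1)                   # broadcast over [num_actions, action_dim]
    noisy = a_bar.sqrt() * clean_actions + (1 - a_bar).sqrt() * noise
    pred = model(past_states, noisy, t)                       # [b, num_actions, action_dim]
    return F.mse_loss(pred, noise)


betas = torch.linspace(1e-4, 0.02, 100)                       # assumed linear schedule
alphas_cumprod = torch.cumprod(1 - betas, dim=0)
model = ToyNoisePredictor(state_dim=10, num_past_states=2, action_dim=7, num_actions=16)
loss = ddpm_training_loss(model, torch.randn(8, 2, 10), torch.randn(8, 16, 7), alphas_cumprod)
loss.backward()
print(loss.item())
```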
forward(past_states, actions, timestep, image=None)
Runs a single denoising step (simplified to one iteration).
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `past_states` |  | Past robot states [batch, num_past_states, state_dim], or flattened | required |
| `actions` |  | Noisy action sequence [batch, num_actions, action_dim] | required |
| `timestep` |  | Diffusion timestep [batch] | required |
| `image` |  | RGB image [batch, 3, 224, 224] (optional) | `None` |
Returns:
| Type | Description |
|---|---|
|  | Predicted noise or action [batch, num_actions, action_dim] |
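Since `forward` performs only one denoising iteration, generating actions means calling it repeatedly from pure noise down to timestep 0. Here is a minimal sketch of standard DDPM ancestral sampling, assuming the model predicts noise (not actions) and using the simple sigma_t^2 = beta_t variance choice. `ddpm_sample` and the zero-noise `zero_model` are hypothetical names used only to exercise the loop.

```python
import torch


@torch.no_grad()
def ddpm_sample(model, past_states, num_actions, action_dim, betas):
    """Ancestral DDPM sampling (assumed): one single-step model call per timestep, T-1 down to 0."""
    alphas = 1.0 - betas
    alphas_cumprod = torch.cumprod(alphas, dim=0)
    b = past_states.shape[0]
    x = torch.randn(b, num_actions, action_dim)               # start from Gaussian noise
    for t in reversed(range(betas.shape[0])):
        t_batch = torch.full((b,), t, dtype=torch.long)
        eps = model(past_states, x, t_batch)                  # assumed to predict the added noise
        # Mean of x_{t-1}: (x_t - beta_t / sqrt(1 - alpha_bar_t) * eps) / sqrt(alpha_t)
        mean = (x - betas[t] / (1 - alphas_cumprod[t]).sqrt() * eps) / alphas[t].sqrt()
        if t > 0:
            x = mean + betas[t].sqrt() * torch.randn_like(x)  # add noise with variance beta_t
        else:
            x = mean                                          # final step is deterministic
    return x


# Hypothetical model that always predicts zero noise, just to run the loop end to end.
zero_model = lambda past_states, actions, timestep, image=None: torch.zeros_like(actions)
actions = ddpm_sample(zero_model, torch.randn(4, 2, 10), num_actions=16, action_dim=7,
                      betas=torch.linspace(1e-4, 0.02, 50))
print(actions.shape)  # torch.Size([4, 16, 7])
```

If the model is configured to predict actions directly (the "noise/action" wording in the return type allows either), the update rule inside the loop would need to change accordingly.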
ImageEncoder
Bases: Module
Simple CNN encoder for RGB images. It uses increased channel counts, for roughly 2-3x more compute.
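The layer layout of `ImageEncoder` is not documented here. As a rough illustration only, the sketch below maps a [batch, 3, 224, 224] image to a feature vector with three stride-2 convolutions and a global average pool written as a plain `mean`. `TinyImageEncoderSketch`, `width`, and `out_dim` are hypothetical names and defaults, not the source's values.

```python
import torch
import torch.nn as nn


class TinyImageEncoderSketch(nn.Module):
    """Hypothetical CNN encoder (not the source ImageEncoder): [B, 3, 224, 224] -> [B, out_dim]."""

    def __init__(self, out_dim: int = 128, width: int = 32):
        super().__init__()
        layers, in_ch = [], 3
        for ch in (width, width * 2, width * 4):  # larger `width` -> more channels per layer
            layers += [nn.Conv2d(in_ch, ch, kernel_size=3, stride=2, padding=1), nn.ReLU()]
            in_ch = ch
        self.features = nn.Sequential(*layers)
        self.proj = nn.Linear(in_ch, out_dim)

    def forward(self, image: torch.Tensor) -> torch.Tensor:
        x = self.features(image)   # [B, 4 * width, 28, 28] for a 224x224 input
        x = x.mean(dim=(2, 3))     # global average pool as a plain reduction
        return self.proj(x)        # [B, out_dim]


enc = TinyImageEncoderSketch()
feats = enc(torch.randn(2, 3, 224, 224))
print(feats.shape)  # torch.Size([2, 128])
```

A convolution's cost scales with input channels times output channels, so widening every layer by a factor k multiplies compute in the inner layers by roughly k squared. That is why a modest channel increase can produce a multi-fold compute increase.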
ONNXCompatibleGlobalPool2d
Bases: Module
ONNX-compatible replacement for AdaptiveMaxPool2d(1) and AdaptiveAvgPool2d(1)
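With an output size of 1, adaptive pooling reduces to a max or mean over the height and width axes. A minimal sketch, assuming the replacement is written as plain tensor reductions (a common way to avoid adaptive pooling ops during export). `GlobalPool2dSketch` and its `mode` argument are hypothetical names, not the source's API.

```python
import torch
import torch.nn as nn


class GlobalPool2dSketch(nn.Module):
    """Global max/avg pooling via plain reductions (assumed approach, not the source code).

    For [B, C, H, W] input this matches nn.AdaptiveMaxPool2d(1) / nn.AdaptiveAvgPool2d(1)
    and returns [B, C, 1, 1].
    """

    def __init__(self, mode: str = "avg"):
        super().__init__()
        if mode not in ("avg", "max"):
            raise ValueError(f"mode must be 'avg' or 'max', got {mode!r}")
        self.mode = mode

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if self.mode == "avg":
            return x.mean(dim=(2, 3), keepdim=True)  # mean over H and W
        return x.amax(dim=(2, 3), keepdim=True)      # max over H and W


x = torch.randn(2, 5, 7, 9)
print(GlobalPool2dSketch("max")(x).shape)  # torch.Size([2, 5, 1, 1])
```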
SimpleBlock
Bases: Module
Simple block with a linear layer, an activation, and a residual connection
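One straightforward reading of "linear + activation + residual" is y = x + act(W x + b). The sketch below assumes that ordering, a ReLU activation, and equal input and output widths. `SimpleBlockSketch` is a hypothetical name, and the source's activation choice is not stated.

```python
import torch
import torch.nn as nn


class SimpleBlockSketch(nn.Module):
    """Hypothetical reading of SimpleBlock: y = x + act(Linear(x))."""

    def __init__(self, dim: int):
        super().__init__()
        self.linear = nn.Linear(dim, dim)  # same in/out size so the residual add is valid
        self.act = nn.ReLU()               # assumed activation

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x + self.act(self.linear(x))


block = SimpleBlockSketch(8)
# With zeroed weights the branch outputs 0, so the block reduces to the identity.
nn.init.zeros_(block.linear.weight)
nn.init.zeros_(block.linear.bias)
x = torch.randn(3, 8)
y = block(x)
```

The residual path means a freshly initialised block starts close to an identity mapping, which helps when many such blocks are stacked.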
SinusoidalPositionEmbeddings
Bases: Module
Sinusoidal position embeddings for diffusion timesteps
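A common formulation of sinusoidal timestep embeddings, borrowed from Transformer positional encodings, concatenates sines and cosines of the timestep at geometrically spaced frequencies. This sketch assumes that formulation; the source's exact frequency scaling and sin/cos ordering are not documented, and `SinusoidalEmbeddingSketch` is a hypothetical name.

```python
import math
import torch
import torch.nn as nn


class SinusoidalEmbeddingSketch(nn.Module):
    """Sinusoidal embedding of scalar timesteps (assumed formulation).

    emb[b, i]        = sin(t_b * w_i)
    emb[b, half + i] = cos(t_b * w_i),   w_i = 10000 ** (-i / (half - 1))
    """

    def __init__(self, dim: int):
        super().__init__()
        if dim % 2 != 0 or dim < 4:
            raise ValueError("dim must be an even number >= 4")
        self.dim = dim

    def forward(self, timesteps: torch.Tensor) -> torch.Tensor:
        half = self.dim // 2
        scale = math.log(10000) / (half - 1)
        # Frequencies from w_0 = 1 down to w_{half-1} = 1e-4
        freqs = torch.exp(torch.arange(half, device=timesteps.device) * -scale)
        args = timesteps.float()[:, None] * freqs[None, :]   # [batch, half]
        return torch.cat([args.sin(), args.cos()], dim=-1)   # [batch, dim]


emb = SinusoidalEmbeddingSketch(16)(torch.tensor([0, 10, 500]))
print(emb.shape)  # torch.Size([3, 16])
```

High-frequency components distinguish nearby timesteps, while low-frequency ones vary smoothly across the whole diffusion schedule. This gives the network a fixed, non-learned encoding of t.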
UNetBlock
Bases: Module
Basic U-Net block with residual connection
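The dimensionality and layer choices of `UNetBlock` are not given here. As one plausible shape for a residual U-Net block over an action sequence, the sketch below assumes 1D convolutions along the time axis with GroupNorm and Mish, plus a 1x1 convolution on the skip path when the channel count changes. `ResidualConv1dBlockSketch` and its parameters are hypothetical.

```python
import torch
import torch.nn as nn


class ResidualConv1dBlockSketch(nn.Module):
    """Hypothetical U-Net-style residual block over a sequence (not the source UNetBlock)."""

    def __init__(self, in_ch: int, out_ch: int, groups: int = 8):
        super().__init__()
        # out_ch must be divisible by `groups` for GroupNorm
        self.body = nn.Sequential(
            nn.Conv1d(in_ch, out_ch, 3, padding=1), nn.GroupNorm(groups, out_ch), nn.Mish(),
            nn.Conv1d(out_ch, out_ch, 3, padding=1), nn.GroupNorm(groups, out_ch), nn.Mish(),
        )
        # Project the skip path only when channel counts differ
        self.skip = nn.Conv1d(in_ch, out_ch, 1) if in_ch != out_ch else nn.Identity()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: [batch, channels, length]; output: [batch, out_ch, length]
        return self.body(x) + self.skip(x)


block = ResidualConv1dBlockSketch(7, 64)
y = block(torch.randn(4, 7, 16))
print(y.shape)  # torch.Size([4, 64, 16])
```

In a full U-Net, blocks like this are stacked with downsampling on the way in and upsampling on the way out, and encoder features are concatenated into the decoder at matching resolutions.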
replace_adaptive_pooling_with_onnx_compatible(module)
Recursively replace AdaptiveMaxPool2d and AdaptiveAvgPool2d with ONNX-compatible versions