Baremetal-NN
Baremetal-NN API documentation
Loading...
Searching...
No Matches
tensor.h
Go to the documentation of this file.
1#ifndef __NN_TENSOR
2#define __NN_TENSOR
3
4#include <stddef.h>
5#include <stdint.h>
6#include <stdlib.h>
7#include <stdio.h>
8#include <assert.h>
9#include <string.h>
10
11#include "float16.h"
12
13#define MAX_DIMS 4
14
15// #define F32(ptr) (*((float *)(ptr)))
16// #define F64(ptr) (*((double *)(ptr)))
17// #define I8(ptr) (*((int8_t *)(ptr)))
18// #define I16(ptr) (*((int16_t *)(ptr)))
19
20
/**
 * Element data types a Tensor can hold.
 *
 * Enumerators take implicit, consecutive values starting at 0.
 * The commented-out entries are placeholders for types that are not
 * implemented yet (they are also commented out in NN_sizeof()).
 */
typedef enum {
  DTYPE_U8,
  DTYPE_I8,
  DTYPE_U16,
  DTYPE_I16,
  DTYPE_U32,
  DTYPE_I32,
  // DTYPE_I64,
  // DTYPE_I128,
  // DTYPE_F8,
  DTYPE_F16,
  DTYPE_F32,
  // DTYPE_F64,
} DataType;
35
36typedef struct {
38 size_t ndim;
39 size_t size;
40 size_t shape[MAX_DIMS];
41 void *data;
42} Tensor;
43
44
45static inline size_t NN_sizeof(DataType dtype) {
46 switch (dtype) {
47 case DTYPE_U8:
48 return sizeof(uint8_t);
49 case DTYPE_I8:
50 return sizeof(int8_t);
51 case DTYPE_U16:
52 return sizeof(uint16_t);
53 case DTYPE_I16:
54 return sizeof(int16_t);
55 case DTYPE_U32:
56 return sizeof(uint32_t);
57 case DTYPE_I32:
58 return sizeof(int32_t);
59 // case DTYPE_I64:
60 // return sizeof(int64_t);
61 case DTYPE_F16:
62 return sizeof(float16_t);
63 case DTYPE_F32:
64 return sizeof(float);
65 // case DTYPE_F64:
66 // return sizeof(double);
67 default:
68 printf("[WARNING] Unsupported data type: %d\n", dtype);
69 return 0;
70 }
71}
72
73static inline const char *NN_get_datatype_name(DataType dtype) {
74 switch (dtype) {
75 case DTYPE_U8:
76 return "UINT8";
77 case DTYPE_I8:
78 return "INT8";
79 case DTYPE_U16:
80 return "UINT16";
81 case DTYPE_I16:
82 return "INT16";
83 case DTYPE_U32:
84 return "UINT32";
85 case DTYPE_I32:
86 return "INT32";
87 // case DTYPE_I64:
88 // return "INT64";
89 case DTYPE_F16:
90 return "FLOAT16";
91 case DTYPE_F32:
92 return "FLOAT32";
93 // case DTYPE_F64:
94 // return "FLOAT64";
95 default:
96 return "UNKNOWN";
97 }
98}
99
107static inline uint8_t NN_is_scalar(Tensor *tensor) {
108 return tensor->ndim == 0;
109}
110
116static inline uint8_t NN_is_vector(Tensor *tensor) {
117 return tensor->ndim == 1;
118}
119
125static inline uint8_t NN_is_matrix(Tensor *tensor) {
126 return tensor->ndim == 2;
127}
128
134static inline uint8_t NN_is_3d(Tensor *tensor) {
135 return tensor->ndim == 3;
136}
137
143static inline uint8_t NN_is_4d(Tensor *tensor) {
144 return tensor->ndim == 4;
145}
146
152static inline void NN_free_tensor_data(Tensor *tensor) {
153 free(tensor->data);
154}
155
161static inline void NN_delete_tensor(Tensor *tensor) {
162 free(tensor);
163}
164
165
166
167#endif // __NN_TENSOR
uint16_t float16_t
Definition: float16.h:21
size_t ndim
Definition: tensor.h:38
DataType dtype
Definition: tensor.h:37
void * data
Definition: tensor.h:41
size_t size
Definition: tensor.h:39
Definition: tensor.h:36
static void NN_delete_tensor(Tensor *tensor)
Definition: tensor.h:161
static void NN_free_tensor_data(Tensor *tensor)
Definition: tensor.h:152
#define MAX_DIMS
Definition: tensor.h:13
static uint8_t NN_is_matrix(Tensor *tensor)
Definition: tensor.h:125
static uint8_t NN_is_3d(Tensor *tensor)
Definition: tensor.h:134
static size_t NN_sizeof(DataType dtype)
Definition: tensor.h:45
static const char * NN_get_datatype_name(DataType dtype)
Definition: tensor.h:73
static uint8_t NN_is_scalar(Tensor *tensor)
Definition: tensor.h:107
DataType
Definition: tensor.h:21
@ DTYPE_U8
Definition: tensor.h:22
@ DTYPE_F32
Definition: tensor.h:32
@ DTYPE_I8
Definition: tensor.h:23
@ DTYPE_I32
Definition: tensor.h:27
@ DTYPE_I16
Definition: tensor.h:25
@ DTYPE_U32
Definition: tensor.h:26
@ DTYPE_U16
Definition: tensor.h:24
@ DTYPE_F16
Definition: tensor.h:31
static uint8_t NN_is_vector(Tensor *tensor)
Definition: tensor.h:116
static uint8_t NN_is_4d(Tensor *tensor)
Definition: tensor.h:143