机器学习基础 2.5:在 C 语言中使用宏实现泛型行为Draft

使用像 C 这样的低级语言所带来的理解上的代价是缺乏特性。当编写用于演示目的的代码时,这并不是一个大问题。但当你需要编写一个实际应用程序时,这就成了问题。

C 语言中缺失的一个显著特性是编译时泛型

编译时泛型是一种代码复用方式,它依赖于以下约束条件:

  1. 存在多个实现相同特性的类
  2. 存在仅通过这些类实现的特性来依赖这些类的代码

在本文中,我想向你展示一种在 C 语言中模拟泛型的方法,特别是针对所有逐元素矩阵操作所依赖的操作特性。我们将使用 C 语言中唯一的内置代码复用方法来实现这一点:宏。

什么是C宏?

C宏是通过#define指令创建的。它们最简单的用法如下:

#define AN_IMPORTANT_NUMBER 42

这会将所有出现的AN_IMPORTANT_NUMBER替换为字面量42。你可能会问:这与以下全局定义有何不同?

const int AN_IMPORTANT_NUMBER = 42;
int main() { ... }

实际上,区别不大。但这只是因为例子太简单了。如果我们写一个宏函数呢?

#define SQUARED_MACRO(x) x*x
int square_function(int x) { return x*x; }

这两者之间有什么区别?当你把宏看作稍微复杂一点的查找和替换时,区别就显而易见了。宏由C预处理器计算,并在编译时将一段代码替换为另一段代码。例如:

int x = SQUARED_MACRO(2);

会被展开为:

int x = 2*2;

而函数则不会被展开,只会生成一个调用指令。宏有几个优点。首先,它没有函数调用的开销。其次,它没有类型限制,因此SQUARED_MACRO可以用于任何支持乘法的类型。

宏的这些优点也带来了一些危险的陷阱。让我们看看SQUARED_MACRO在一些有趣的情况下会展开成什么。

SQUARED_MACRO(2 + 3)
2 + 3*2 + 3

我们希望得到25,但实际得到11。这不太好。不过我们可以轻松解决这个问题。

#define SQUARED_MACRO (x)*(x)

另一个例子:

int x = 1, sum = 0;
while (x < 10) { sum += SQUARED_MACRO(x++); }

会被展开为:

int x = 1, sum = 0;
while (x < 10) { sum += (x++)*(x++); }

我们看到x被递增了两次,而不是一次!此外,它计算的是x*(x+1),而不是x*x

因此,我们看到这个简单的宏在某些输入下会输出完全无意义的结果,而这些输入在编译器看来是有效的。尽管宏重复代码的特性导致了这些危险行为,但它也让我们能够更高效地编写代码。

## 库的目标

我希望这个库是极简的(仅头文件),但功能完备。我们将编写用于两个矩阵 之间各种操作的函数。

我将必要的函数按以下方式分类。

  • 分配类型:我们应该将结果放入缓冲区 ,还是将其写入第一个参数
  • 循环类型:我们是进行点积类型的循环,还是逐元素操作?我们是否应该转置第二个参数
  • 逐元素操作类型:是乘法?加法?减法?

注意,所有函数(除了点积)都对操作 是通用的。

样板代码

在开始之前,我们需要定义我们的矩阵类型。我还将定义一个名为 mfloat 的类型,稍后我们可以将其替换为任何浮点类型。

#define DEBUG 1

typedef double mfloat;

typedef struct {
  mfloat *buf;
  int rows;
  int cols;
} matrix;

我们假设元素以行优先顺序存储在 buf 中。现在,我们需要一些基本的获取和设置函数,这些函数带有可选的边界检查,我们将在库的其余部分中使用它们。

static inline mfloat matrix_get(matrix m, int row, int col) {
  if (DEBUG)
    if (!(row >= 0 && col >= 0 && row < m.rows && col < m.cols)) {
      fprintf(
          stderr,
          "matrix_get: 索引越界 (%d, %d) 对于矩阵大小 (%d, %d)\n",
          row, col, m.rows, m.cols);
      exit(1);
    }

  return m.buf[row * m.cols + col];
}

static inline void matrix_set(matrix m, int row, int col,
                              mfloat val) {
  if (DEBUG)
    if (!(row >= 0 && col >= 0 && row < m.rows && col < m.cols)) {
      fprintf(
          stderr,
          "matrix_set: 索引越界 (%d, %d) 对于矩阵大小 (%d, %d)\n",
          row, col, m.rows, m.cols);
      exit(1);
    }

  m.buf[row * m.cols + col] = val;
}

我们还需要一种将矩阵打印到控制台的方法

static inline void matrix_print(matrix m) {
  for (int i = 0; i < m.rows; i++) {
    printf("[ ");
    for (int j = 0; j < m.cols; j++) {
      // 科学计数法,四舍五入到小数点后四位
      printf("%.4e", matrix_get(m, i, j));
      printf(" ");
    }
    printf("]\n");
  }
  printf("\n");
}

以及一种在堆上分配和释放矩阵的方法

matrix matrix_new(int rows, int cols) {
  double *buf = calloc(rows * cols, sizeof(double));
  if (buf == NULL) {
    printf("matrix_new: calloc 失败。");
    exit(1);
  }

  return (matrix){
      .buf = buf,
      .rows = rows,
      .cols = cols,
  };
}

点积

让我们从点积开始。我只会编写一个分配缓冲区的版本,因为只有在两个矩阵都是方阵的情况下才能将输出写入第一个参数,而我们不能假设这一点。

static inline void matrix_dot(matrix out, const matrix m1,
                              const matrix m2) {
  if (DEBUG)
    if (m1.cols != m2.rows) {
      printf(
          "矩阵点积:维度错误 (%d, %d) 不兼容 (%d, %d)\n",
          m1.rows, m1.cols, m2.rows, m2.cols);
      exit(1);
    }
  for (int row = 0; row < m1.rows; row++) {
    for (int col = 0; col < m2.cols; col++) {
      double sum = 0.0;
      for (int k = 0; k < m1.cols; k++) {
        double x1 = matrix_get(m1, row, k);
        double x2 = matrix_get(m2, k, col);
        sum += x1 * x2;
      }
      matrix_set(out, row, col, sum);
    }
  }
}

逐元素操作

每个逐元素操作执行以下步骤:

  1. 确保两个矩阵具有相同的维度
  2. 对每个对应元素执行操作
  3. 将结果放入输出矩阵

我们发现每个函数的循环部分都是相同的,因此我们可以为此编写一个宏

#define MAT_ELEMENTWISE_LOOP        \
  for (int i = 0; i < m1.rows; i++) \
    for (int j = 0; j < m1.cols; j++)

以及一个检查边界的函数,如果边界不匹配则直接报错。

static inline void mat_bounds_check_elementwise(const matrix out,
                                                const matrix m1,
                                                const matrix m2) {
  if (DEBUG)
    if (m1.rows != m2.rows || m1.cols != m2.cols ||
        out.rows != m1.rows || out.cols != m1.cols) {
      fprintf(stderr,
              "逐元素操作的维度不兼容 "
              "(%d, %d) & (%d, %d) => (%d, %d) \n",
              m1.rows, m1.cols, m2.rows, m2.cols, out.rows,
              out.cols);
      exit(1);
    }
}

现在,我们想要实现加法、乘法、除法和减法。由于除了实际计算之外的代码都是相同的,我们可以将其抽象为一个宏来定义函数。

#define DEF_MAT_ELEMENTWISE_BUF(opname, op)           \
  static inline void matrix_##opname(                 \
      matrix out, const matrix m1, const matrix m2) { \
    mat_bounds_check_elementwise(out, m1, m2);        \
    MAT_ELEMENTWISE_LOOP {                            \
      mfloat x = matrix_get(m1, i, j);                \
      mfloat y = matrix_get(m2, i, j);                \
      matrix_set(out, i, j, op);                      \
    }                                                 \
  }

##opnameopname 的值插入到函数名中。

展望未来,我们知道我们需要为所有变体定义加法、乘法、除法和减法函数,因此让我们编写一个宏来为给定的函数定义宏执行此操作

#define DEF_ALL_OPS(OP_MACRO) \
  OP_MACRO(sub, (x - y));     \
  OP_MACRO(add, (x + y));     \
  OP_MACRO(div, (x / y));     \
  OP_MACRO(mul, (x * y));

现在我们可以实际定义这4个函数了!

DEF_ALL_OPS(DEF_MAT_ELEMENTWISE_BUF)

砰!在一行代码中,我们定义了 matrix_addmatrix_submatrix_divmatrix_mul 函数。

现在让我们尝试实现原地操作。

static inline void mat_bounds_check_elementwise_ip(
    matrix m1, const matrix m2) {
  if (DEBUG)
    if (m1.rows != m2.rows || m1.cols != m2.cols) {
      fprintf(stderr,
              "原地逐元素操作的维度不兼容 "
              "(%d, %d) & (%d, %d) \n",
              m1.rows, m1.cols, m2.rows, m2.cols);
      exit(1);
    }
}

#define DEF_MAT_ELEMENTWISE_IP(opname, op)                 \
  static inline void matrix_ip_##opname(matrix m1,         \
                                        const matrix m2) { \
    mat_bounds_check_elementwise_ip(m1, m2);               \
    MAT_ELEMENTWISE_LOOP {                                 \
      mfloat x = matrix_get(m1, i, j);                     \
      mfloat y = matrix_get(m2, i, j);                     \
      matrix_set(m1, i, j, op);                            \
    }                                                      \
  }

有了这个新宏,我们只需执行

DEF_ALL_OPS(DEF_MAT_ELEMENTWISE_IP)

然后 matrix_ip_addmatrix_ip_mul 等函数就被定义了。我们也可以为转置操作执行此操作,但这里不再展示。

一元操作

有时我们希望对矩阵 进行一元操作,例如标量操作 。让我们创建一个通用的函数来处理这些操作。

#define DEF_MAT_UNARY_IP(opname, op)            \
  static inline void matrix_ip_##opname(matrix m1) { \
    MAT_ELEMENTWISE_LOOP {                           \
      mfloat x = matrix_get(m1, i, j);               \
      matrix_set(m1, i, j, op);                      \
    }                                                \
  }

DEF_MAT_UNARY_IP(square, (x * x))
DEF_MAT_UNARY_IP(negate, (-x))
DEF_MAT_UNARY_IP(sqrt, (sqrt(x)))

现在 matrix_ip_square(A)matrix_ip_negate(A) 等函数已经定义好了。

结论

就这样,我们以最小的努力“编写”了一个完整的矩阵运算库。但这引出了一个问题:这段代码安全吗?

只要你 #undef 所有定义的宏,那么是的,它和你自己编写所有函数一样安全。另一方面,如果你将宏作为库功能的一部分暴露出去,它可能就不再安全了。但仅仅因为某些代码是安全的,并不意味着你应该这样做。如果代码出现问题,可能很难调试,因为你实际上无法看到它被扩展成什么样子。因此,在生产环境中,强烈不建议使用宏。我在这里使用它,只是因为它是教育系列中的一个简便方法。

如果你有任何问题或建议,欢迎留言或给我发邮件。感谢阅读。

✦ No LLMs were used in the ideation, research, writing, or editing of this article.