CONTENTS

机器学习基础 2.5:使用宏在C语言中实现泛型行为Draft

使用C这类低级语言所带来的理解深度,其代价是语言特性的缺失。在编写演示用途的代码时,这并不构成问题。但当你需要编写实际应用程序时,情况就不同了。

C语言中一个显著缺失的特性是编译时泛型

编译时泛型是一种代码复用机制,它依赖于以下约束条件:

  1. 存在多个实现了同一特征的类
  2. 存在仅通过这些类所实现的特征来依赖它们的代码

在本文中,我想展示一种在C语言中模拟泛型的方法,特别是针对所有逐元素矩阵运算都将依赖的操作特征。我们将使用C语言中唯一内置的代码复用方法来实现:宏。

什么是 C 宏?

C 宏通过 #define 指令创建。最简单的用法如下:

#define AN_IMPORTANT_NUMBER 42

这会将所有出现的 AN_IMPORTANT_NUMBER 替换为字面量 42。你可能想知道:这与下面的全局定义有何不同?

const int AN_IMPORTANT_NUMBER = 42;
int main() { ... }

实际上,区别不大。但这只是因为示例太简单。如果我们写一个宏函数呢?

#define SQUARED_MACRO(x) x*x
int square_function(int x) { return x*x; }

这两者有什么区别?当你把宏看作一种稍复杂的查找替换时,区别就清晰了。宏由 C 预处理器处理,在编译时将一段代码替换为另一段代码。例如:

int x = SQUARED_MACRO(2);

会展开为:

int x = 2*2;

而函数则不会展开,只会生成一条调用指令。宏有几个优点:首先,它没有函数调用的开销;其次,它是无类型的,因此 SQUARED_MACRO 可以用于任何支持乘法的类型。

带来这些优点的宏特性,也同时制造了许多危险的陷阱。让我们看看 SQUARED_MACRO 在一些有趣情况下的展开结果。

SQUARED_MACRO(2 + 3)
2 + 3*2 + 3

我们想要 25,却得到了 11。这不好。不过我们可以轻松解决这个问题。

#define SQUARED_MACRO (x)*(x)

另一个例子:

int x = 1, sum = 0;
while (x < 10) { sum += SQUARED_MACRO(x++); }

会展开为:

int x = 1, sum = 0;
while (x < 10) { sum += (x++)*(x++); }

我们看到 x 被递增了两次,而不是一次!此外,它计算的是 x*(x+1),而不是 x*x

由此可见,这个简单的宏在某些输入下会产生完全错误的结果,而编译器却认为这些输入是有效的。尽管宏重复代码的特性导致了这些危险行为,但它也让我们能实现更高的效率。

库的目标

我希望这个库是极简的(仅头文件),但功能完整。我们将编写一系列用于两个矩阵 之间运算的函数。

我将必要的函数按以下方式分类。

  • 分配类型:结果应该放入缓冲区 ,还是写入第一个参数
  • 循环类型:我们执行的是点积类型的循环,还是逐元素操作?是否需要转置第二个参数
  • 逐元素操作类型:是乘法?加法?还是减法?

请注意,所有函数(除了点积)都对操作 是通用的。

样板代码

在开始之前,我们需要定义矩阵类型。我还会定义一个名为 mfloat 的类型,稍后我们可以将其替换为任何浮点类型。

#define DEBUG 1

typedef double mfloat;

typedef struct {
  mfloat *buf;
  int rows;
  int cols;
} matrix;

我们假设元素以行主序存储在 buf 中。现在,我们需要一些基本的 getter-setter 函数,包含可选的边界检查,将在库的其余部分使用。

static inline mfloat matrix_get(matrix m, int row, int col) {
  if (DEBUG)
    if (!(row >= 0 && col >= 0 && row < m.rows && col < m.cols)) {
      fprintf(
          stderr,
          "matrix_get: 索引越界 (%d, %d),矩阵尺寸为 "
          "(%d, %d)\n",
          row, col, m.rows, m.cols);
      exit(1);
    }

  return m.buf[row * m.cols + col];
}

static inline void matrix_set(matrix m, int row, int col,
                              mfloat val) {
  if (DEBUG)
    if (!(row >= 0 && col >= 0 && row < m.rows && col < m.cols)) {
      fprintf(
          stderr,
          "matrix_set: 索引越界 (%d, %d),矩阵尺寸为 "
          "(%d, %d)\n",
          row, col, m.rows, m.cols);
      exit(1);
    }

  m.buf[row * m.cols + col] = val;
}

我们还需要一种将矩阵打印到控制台的方法

static inline void matrix_print(matrix m) {
  for (int i = 0; i < m.rows; i++) {
    printf("[ ");
    for (int j = 0; j < m.cols; j++) {
      // 科学计数法,四舍五入到4位小数
      printf("%.4e", matrix_get(m, i, j));
      printf(" ");
    }
    printf("]\n");
  }
  printf("\n");
}

以及一种在堆上分配和释放矩阵的方法

matrix matrix_new(int rows, int cols) {
  double *buf = calloc(rows * cols, sizeof(double));
  if (buf == NULL) {
    printf("matrix_new: calloc 失败。");
    exit(1);
  }

  return (matrix){
      .buf = buf,
      .rows = rows,
      .cols = cols,
  };
}

## 点积

让我们从点积开始。我只写一个分配缓冲区的版本,因为只有当两个矩阵都是方阵时才能将输出写入第一个参数,而我们不能假设这一点。

```c
static inline void matrix_dot(matrix out, const matrix m1,
                              const matrix m2) {
  if (DEBUG)
    if (m1.cols != m2.rows) {
      printf(
          "矩阵点积:维度错误 (%d, %d) 与 "
          "(%d, %d) 不兼容\n",
          m1.rows, m1.cols, m2.rows, m2.cols);
      exit(1);
    }
  for (int row = 0; row < m1.rows; row++) {
    for (int col = 0; col < m2.cols; col++) {
      double sum = 0.0;
      for (int k = 0; k < m1.cols; k++) {
        double x1 = matrix_get(m1, row, k);
        double x2 = matrix_get(m2, k, col);
        sum += x1 * x2;
      }
      matrix_set(out, row, col, sum);
    }
  }
}

## 逐元素运算

每个逐元素运算执行以下步骤:

1. 确保两个矩阵维度相同
2. 对每个对应元素执行运算
3. 将结果放入输出矩阵

我们发现所有函数的循环结构都相同,因此为它编写一个宏:

```c
#define MAT_ELEMENTWISE_LOOP        \
  for (int i = 0; i < m1.rows; i++) \
    for (int j = 0; j < m1.cols; j++)

再编写一个边界检查函数,当维度不匹配时直接报错退出:

static inline void mat_bounds_check_elementwise(const matrix out,
                                                const matrix m1,
                                                const matrix m2) {
  if (DEBUG)
    if (m1.rows != m2.rows || m1.cols != m2.cols ||
        out.rows != m1.rows || out.cols != m1.cols) {
      fprintf(stderr,
              "逐元素运算维度不兼容 "
              "(%d, %d) & (%d, %d) => (%d, %d) \n",
              m1.rows, m1.cols, m2.rows, m2.cols, out.rows,
              out.cols);
      exit(1);
    }
}

现在我们要实现加法、乘法、除法和减法。由于除了实际计算部分外其他代码都相同,我们可以将其抽象成一个定义函数的宏:

#define DEF_MAT_ELEMENTWISE_BUF(opname, op)           \
  static inline void matrix_##opname(                 \
      matrix out, const matrix m1, const matrix m2) { \
    mat_bounds_check_elementwise(out, m1, m2);        \
    MAT_ELEMENTWISE_LOOP {                            \
      mfloat x = matrix_get(m1, i, j);                \
      mfloat y = matrix_get(m2, i, j);                \
      matrix_set(out, i, j, op);                      \
    }                                                 \
  }

##opname 会将 opname 的值插入到函数名中。

考虑到后续需要为所有变体定义加、乘、除、减函数,我们编写一个宏来为给定的函数定义宏生成所有操作:

#define DEF_ALL_OPS(OP_MACRO) \
  OP_MACRO(sub, (x - y));     \
  OP_MACRO(add, (x + y));     \
  OP_MACRO(div, (x / y));     \
  OP_MACRO(mul, (x * y));

现在我们可以用一行代码定义全部四个函数:

DEF_ALL_OPS(DEF_MAT_ELEMENTWISE_BUF)

完成!通过这一行代码,我们定义了 matrix_addmatrix_submatrix_divmatrix_mul 四个函数。

接下来实现原地运算:

static inline void mat_bounds_check_elementwise_ip(
    matrix m1, const matrix m2) {
  if (DEBUG)
    if (m1.rows != m2.rows || m1.cols != m2.cols) {
      fprintf(stderr,
              "原地逐元素运算维度不兼容 "
              "(%d, %d) & (%d, %d) \n",
              m1.rows, m1.cols, m2.rows, m2.cols);
      exit(1);
    }
}

#define DEF_MAT_ELEMENTWISE_IP(opname, op)                 \
  static inline void matrix_ip_##opname(matrix m1,         \
                                        const matrix m2) { \
    mat_bounds_check_elementwise_ip(m1, m2);               \
    MAT_ELEMENTWISE_LOOP {                                 \
      mfloat x = matrix_get(m1, i, j);                     \
      mfloat y = matrix_get(m2, i, j);                     \
      matrix_set(m1, i, j, op);                            \
    }                                                      \
  }

使用这个新宏,只需执行:

DEF_ALL_OPS(DEF_MAT_ELEMENTWISE_IP)

即可定义 matrix_ip_addmatrix_ip_mul 等函数。转置操作也可采用类似方式实现,此处不再展示。

一元运算

有时我们需要对矩阵 进行一元运算,例如标量运算 。让我们创建一个针对运算的通用一元函数。

#define DEF_MAT_UNARY_IP(opname, op)            \
  static inline void matrix_ip_##opname(matrix m1) { \
    MAT_ELEMENTWISE_LOOP {                           \
      mfloat x = matrix_get(m1, i, j);               \
      matrix_set(m1, i, j, op);                      \
    }                                                \
  }

DEF_MAT_UNARY_IP(square, (x * x))
DEF_MAT_UNARY_IP(negate, (-x))
DEF_MAT_UNARY_IP(sqrt, (sqrt(x)))

现在 matrix_ip_square(A)matrix_ip_negate(A) 等函数已定义。

总结

就这样,我们以最小的努力“编写”了一个完整的矩阵运算库。但这引出了一个问题:这段代码安全吗?

只要你对你定义的所有宏都进行 #undef 处理,那么是的,它和你自己编写所有函数一样安全。另一方面,如果你将宏作为库功能的一部分暴露出去,它可能就不再安全了。但仅仅因为某些代码安全,并不意味着你应该这样做。如果代码出现问题,调试起来可能会很困难,因为你实际上看不到它被展开成什么样子。因此,在生产环境中,强烈不建议使用宏。我在这里使用它,只是因为对于这个教育系列来说,它是一种简单的捷径。

如果你有任何问题或建议,欢迎留言或给我发邮件。感谢阅读。

✦ 本文的构思、研究、撰写和编辑均未使用大语言模型。