tf.tensor
2025年02月11日
一、认识
在 TensorFlow
中,Tensor
(张量)是数据的基本单位,可以理解为多维数组。它类似于 NumPy
的 ndarray
,但与 NumPy
数组相比,Tensor
具有以下特点:
-
支持
GPU/TPU
加速:Tensor
可以在CPU
、GPU
或TPU
上运行,提高计算效率。 -
不可变性(
Immutable)
:Tensor
是不可变的,一旦创建,就无法修改其内容(但可以创建新的Tensor
)。 -
自动微分:
TensorFlow
通过tf.GradientTape
记录计算过程,以支持自动求导。
二、语法
tf.tensor(values, shape?, dtype?)
三、用法
2.1 创建张量
const tf = require("@tensorflow/tfjs");
// 创建一个标量张量(单个数值)
const t0 = tf.tensor(1);
// 打印张量 t0 的值
t0.print();
console.log("t0", t0);
// 创建一个一维张量(数组)
const t1 = tf.tensor([1, 2, 3, 4]);
// 打印张量 t1 的值
t1.print();
console.log("t1", t1);
// 创建一个二维张量(矩阵)
const t2 = tf.tensor([
[1, 2],
[3, 4],
[5, 6],
]);
// 打印张量 t2 的值
t2.print();
console.log("t2", t2);
// 创建一个三维张量
const t3 = tf.tensor([
[
[1, 2],
[3, 4],
],
[
[5, 6],
[7, 8],
],
]);
// 打印张量 t3 的值
t3.print();
console.log("t3", t3);
2.2 张量运算
const x = tf.tensor([1, 2, 3]);
const y = tf.tensor([4, 5, 6]);
x.add(y).print(); // [5, 7, 9]
x.mul(y).print(); // [4, 10, 18]
x.sub(y).print(); // [-3, -3, -3]
x.div(y).print(); // [0.25, 0.4, 0.5]
x.sum().print(); // 1 + 2 + 3 = 6
x.mean().print(); // (1 + 2 + 3) / 3 = 2