使用 NDArray 进行数据交互

导入 np (numpy类似) 模块和 npx (numpy 扩展) 模块。

In [1]:
from mxnet import np, npx
# 启用 MXNet 的 numpy 兼容模式。
npx.set_np()  

创建一个向量并访问属性。

In [2]:
x = np.arange(12)
x
Out[2]:
array([ 0.,  1.,  2.,  3.,  4.,  5.,  6.,  7.,  8.,  9., 10., 11.])
In [3]:
x.shape
Out[3]:
(12,)
In [4]:
x.size
Out[4]:
12

更多创建的方式。

In [5]:
np.zeros((3, 4))
Out[5]:
array([[0., 0., 0., 0.],
       [0., 0., 0., 0.],
       [0., 0., 0., 0.]])
In [6]:
np.array([[2, 1, 4, 3], [1, 2, 3, 4], [4, 3, 2, 1]])
Out[6]:
array([[2., 1., 4., 3.],
       [1., 2., 3., 4.],
       [4., 3., 2., 1.]])
In [7]:
np.random.normal(0, 1, size=(3, 4))
Out[7]:
array([[ 2.2122064 ,  0.7740038 ,  1.0434405 ,  1.1839255 ],
       [ 1.8917114 , -1.2347414 , -1.771029  , -0.45138445],
       [ 0.57938355, -1.856082  , -1.9768796 , -0.20801921]])

按元素的操作。

In [8]:
x = np.array([1, 2, 4, 8])
y = np.ones_like(x) * 2
print('x =', x)
print('x + y', x + y)
print('x - y', x - y)
print('x * y', x * y)
print('x ** y', x ** y)
print('x / y', x / y)
x = [1. 2. 4. 8.]
x + y [ 3.  4.  6. 10.]
x - y [-1.  0.  2.  6.]
x * y [ 2.  4.  8. 16.]
x ** y [ 1.  4. 16. 64.]
x / y [0.5 1.  2.  4. ]

矩阵乘法。

In [9]:
x = np.arange(12).reshape((3,4))
y = np.array([[2, 1, 4, 3], [1, 2, 3, 4], [4, 3, 2, 1]])
np.dot(x, y.T)
Out[9]:
array([[ 18.,  20.,  10.],
       [ 58.,  60.,  50.],
       [ 98., 100.,  90.]])

沿着某个坐标的合并。

In [10]:
np.concatenate([x, y], axis=0), np.concatenate([x, y], axis=1)
Out[10]:
(array([[ 0.,  1.,  2.,  3.],
        [ 4.,  5.,  6.,  7.],
        [ 8.,  9., 10., 11.],
        [ 2.,  1.,  4.,  3.],
        [ 1.,  2.,  3.,  4.],
        [ 4.,  3.,  2.,  1.]]),
 array([[ 0.,  1.,  2.,  3.,  2.,  1.,  4.,  3.],
        [ 4.,  5.,  6.,  7.,  1.,  2.,  3.,  4.],
        [ 8.,  9., 10., 11.,  4.,  3.,  2.,  1.]]))

广播机制。

In [11]:
a = np.arange(3).reshape((3, 1))
b = np.arange(2).reshape((1, 2))
print('a:\n', a)
print('b:\n', b)
a + b
a:
 [[0.]
 [1.]
 [2.]]
b:
 [[0. 1.]]
Out[11]:
array([[0., 1.],
       [1., 2.],
       [2., 3.]])

访问元素。

In [12]:
print('x[-1] =\n', x[-1])
print('x[1:3] =\n', x[1:3])
print('x[1:3, 2:4] =\n', x[1:3, 2:4])
print('x[1,2] =', x[1,2])
x[-1] =
 [ 8.  9. 10. 11.]
x[1:3] =
 [[ 4.  5.  6.  7.]
 [ 8.  9. 10. 11.]]
x[1:3, 2:4] =
 [[ 6.  7.]
 [10. 11.]]
x[1,2] = 6.0

mxnet.numpy.ndarraynumpy.ndarray 互换。

In [13]:
a = x.asnumpy()
print(type(a))
b = np.array(a)
print(type(b))
<class 'numpy.ndarray'>
<class 'mxnet.numpy.ndarray'>