TensorFlow的tf.split函数可用于切分张量(Tensor),将张量按照指定的维度分成多份。该函数的作用是将张量按照指定的维度分割成n份并打包成一个列表返回。
该函数的使用方法如下:
tf.split(
value,
num_or_size_splits,
axis=0,
num=None,
name='split'
)
参数说明:
- value:待切分的张量,可以是任意维度的张量。
- num_or_size_splits:代表切片数量的一个整数,或者每一份长度组成的列表(张量),如[2,3,4]表示将原始张量分为3份,每份的长度分别是2、3、4。当num_or_size_splits是一个整数时,表示将原始张量切成该整数数量的张量。如果该参数是一个列表,列表的长度必须能够被value的axis维度整除。
- axis:代表切分的维度的索引,比如2表示操作张量的第三个维度。
- num:将该参数指定为一个整数,则表示借助该参数与value的axis维度,将张量等长地切分成num份,若num==2,即在axis维度上等分成两份,如果value的axis维度无法被num整除,则会报错。
- name:操作名称,可选参数。
下面是tf.split函数的两个实际使用例子:
1.根据张量value,按照axis=1将其分为两个长度相等的张量output1和output2:
import tensorflow as tf
# 创建一个4*6的tensor
value = tf.constant([[1,2,3,4,5,6], [7,8,9,10,11,12], [13,14,15,16,17,18], [19,20,21,22,23,24]], dtype=tf.float32)
output1, output2 = tf.split(value, num_or_size_splits=2, axis=1)
print("output1: ", output1)
print("output2: ", output2)
输出结果:
output1: tf.Tensor(
[[ 1. 2. 3.]
[ 7. 8. 9.]
[13. 14. 15.]
[19. 20. 21.]], shape=(4, 3), dtype=float32)
output2: tf.Tensor(
[[ 4. 5. 6.]
[10. 11. 12.]
[16. 17. 18.]
[22. 23. 24.]], shape=(4, 3), dtype=float32)
2.将张量value根据传入的列表num_or_size_splits,将张量沿axis=0维度切分成6份:
import tensorflow as tf
# 创建一个4*8的tensor
value = tf.constant([[1,2,3,4,5,6,7],
[8,9,10,11,12,13,14],
[15,16,17,18,19,20,21],
[22,23,24,25,26,27,28]], dtype=tf.float32)
output = tf.split(value, num_or_size_splits=[1,2,3], axis=0)
print("output[0]:", output[0])
print("output[1]:", output[1])
print("output[2]:", output[2])
print("output[3]:", output[3])
print("output[4]:", output[4])
print("output[5]:", output[5])
输出结果:
output[0]: tf.Tensor([[1. 2. 3. 4. 5. 6. 7.]], shape=(1, 7), dtype=float32)
output[1]: tf.Tensor(
[[ 8. 9. 10. 11. 12. 13. 14.]
[15. 16. 17. 18. 19. 20. 21.]], shape=(2, 7), dtype=float32)
output[2]: tf.Tensor(
[[22. 23. 24. 25. 26. 27. 28.]], shape=(1, 7), dtype=float32)
output[3]: tf.Tensor([], shape=(0, 7), dtype=float32)
output[4]: tf.Tensor([], shape=(0, 7), dtype=float32)
output[5]: tf.Tensor([], shape=(0, 7), dtype=float32)
可以看到,由于num_or_size_splits列表加起来等于6,因此output将value切分成了6份,其中前三份分别是长度为1、2、3的张量,后三份是长度为0的张量。