Tensorflow中修改tensor的方法
在TensorFlow中tensor是不能直接修改数值的,如:
import tensorflow as tf
tensor_1 = tf.constant([x for x in range(1,10)])
# tensor_1 是一个数值为1到9的张量,希望把中间第五个数值改为0
tensor_1[4] = 0
这时就会报错,错误类型是:
TypeError: 'Tensor' object does not support item assignment
下面总结两种方法来修改tensor:
1、生成一个新的tensor
# 方法一 : 运用concat函数
tensor_1 = tf.constant([x for x in range(1,10)])
# 将原来的张量拆分为3部分,修改位置前的部分,要修改的部分和修改位置之后的部分
i = 4
part1 = tensor_1[:i]
part2 = tensor_1[i+1:]
val = tf.constant([0])
new_tensor = tf.concat([part1,val,part2], axis=0)
2、使用one_hot进行加减运算
# 方法二:使用one_hot来进行加减运算
tensor_1 = tf.constant([x for x in range(1,10)])
i = 4
# 生成一个one_hot张量,长度与tensor_1相同,修改位置为1
shape = tensor_1.get_shape().as_list()
one_hot = tf.one_hot(i,shape[0],dtype=tf.int32)
# 做一个减法运算,将one_hot为一的变为原张量该位置的值进行相减
new_tensor = tensor_1 - tensor_1[i] * one_hot
3、使用TensorFlow自带的assign()函数(修改的tensor必须为变量(Variable))
import tensorflow as tf
#create a Variable
x=tf.Variable(initial_value=[[1,1],[1,1]],dtype=tf.float32,validate_shape=False)
init_op=tf.global_variables_initializer()
update=tf.assign(x,[[1,2],[1,2]])
with tf.Session() as session:
session.run(init_op)
session.run(update)
x=session.run(x)
print(x)
tensorflow使用assign(variable,new_value)来更改变量的值,但是真正作用在garph中,必须要调用gpu或者cpu运行这个更新过:session.run(update) 。
参考