【Tensorflow】Tensorflow中修改tensor的方法

2022-12-24 08:05:53

                                             Tensorflow中修改tensor的方法

在TensorFlow中tensor是不能直接修改数值的,如:

import tensorflow as tf
tensor_1 = tf.constant([x for x in range(1,10)])
# tensor_1 是一个数值为1到9的张量,希望把中间第五个数值改为0 
tensor_1[4] = 0 

这时就会报错,错误类型是:

TypeError: 'Tensor' object does not support item assignment

下面总结两种方法来修改tensor:

1、生成一个新的tensor

# 方法一 : 运用concat函数
tensor_1 = tf.constant([x for x in range(1,10)])
# 将原来的张量拆分为3部分,修改位置前的部分,要修改的部分和修改位置之后的部分
i = 4
part1 = tensor_1[:i]
part2 = tensor_1[i+1:]
val = tf.constant([0])
new_tensor = tf.concat([part1,val,part2], axis=0)

2、使用one_hot进行加减运算

# 方法二:使用one_hot来进行加减运算
tensor_1 = tf.constant([x for x in range(1,10)])
i = 4
# 生成一个one_hot张量,长度与tensor_1相同,修改位置为1
shape = tensor_1.get_shape().as_list()
one_hot = tf.one_hot(i,shape[0],dtype=tf.int32)
# 做一个减法运算,将one_hot为一的变为原张量该位置的值进行相减
new_tensor = tensor_1 - tensor_1[i] * one_hot

3、使用TensorFlow自带的assign()函数(修改的tensor必须为变量(Variable))

import tensorflow as tf

#create a Variable
x=tf.Variable(initial_value=[[1,1],[1,1]],dtype=tf.float32,validate_shape=False)

init_op=tf.global_variables_initializer()
update=tf.assign(x,[[1,2],[1,2]])

with tf.Session() as session:
    session.run(init_op)
    session.run(update)
    x=session.run(x)
    print(x)

tensorflow使用assign(variable,new_value)来更改变量的值,但是真正作用在garph中,必须要调用gpu或者cpu运行这个更新过:session.run(update) 。

 

参考

1、tensorflow更改变量的值

2、Tensorflow小技巧整理:修改张量特定元素的值

  • 作者:xiaohe9275
  • 原文链接:https://blog.csdn.net/xiaohe9275/article/details/80824359
    更新时间:2022-12-24 08:05:53