Tensorflow

텐서플로우: tf.GradientTape.watch() 설명

카카오그래놀라 2020. 11. 6. 01:03

설명

텐서가 GradientTape(미분 계산 값을 기록하는 객체)에 의해 추적되도록 합니다.

인자 값으로는 텐서 또는 텐서가 담긴 리스트를 받습니다.

만약, 텐서가 아닐 경우, valueError를 띄웁니다.

 

Ensures that tensor is being traced by this tape.

 

  • g.watch(y)가 없는 경우

x = tf.constant(3.0)
y = tf.constant(3.0)

with tf.GradientTape(persistent=True) as g:
  g.watch(x)
  
  z0 = x ** 2
  z1 = y ** 2

dz_dx = g.gradient(z0, x)  # x^2의 미분값은 2*x (x는 위에서 3이기에 결과는 6)
print(dz_dx) # tf.Tensor(6.0, shape=(), dtype=float32)

dz_dy = g.gradient(z1, y)  # dz_dy는 None 임!!
print(dz_dy) # None

 

  • g.watch(y)가 있는 경우

x = tf.constant(3.0)
y = tf.constant(3.0)

with tf.GradientTape(persistent=True) as g:
  g.watch(x)
  g.watch(y) # 추가함!!!
  
  z0 = x ** 2
  z1 = y ** 2

dz_dx = g.gradient(z0, x)
print(dz_dx) # tf.Tensor(6.0, shape=(), dtype=float32)

dz_dy = g.gradient(z1, y)  # watch를 추가했더니 
print(dz_dy) # tf.Tensor(6.0, shape=(), dtype=float32) 이제 return 됨!

 

 

공식 문서

 

tf.GradientTape  |  TensorFlow Core v2.3.0

Record operations for automatic differentiation.

www.tensorflow.org

소스 코드

 

tensorflow/tensorflow

An Open Source Machine Learning Framework for Everyone - tensorflow/tensorflow

github.com