tf.data.Dataset 기초
repeat(), batch(), take()
코드와 출력문을 보면 이해가 더 빠릅니다.
def count(stop):
i = 0
while i<stop:
yield i
i += 1
ds_counter = tf.data.Dataset.from_generator(count, args=[25], output_types=tf.int32, output_shapes = (), )
for count_batch in ds_counter.repeat().batch(8).take(10):
print(count_batch.numpy())
# [0 1 2 3 4 5 6 7]
# [ 8 9 10 11 12 13 14 15]
# [16 17 18 19 20 21 22 23]
# [24 0 1 2 3 4 5 6]
# [ 7 8 9 10 11 12 13 14]
# [15 16 17 18 19 20 21 22]
# [23 24 0 1 2 3 4 5]
# [ 6 7 8 9 10 11 12 13]
# [14 15 16 17 18 19 20 21]
# [22 23 24 0 1 2 3 4]
간단 설명
repeat(): 데이터 셋을 반복한다.
batch(): 데이터 배치의 크기를 정한다. (위의 코드에서는 1배치 당 8개의 데이터)
take(): 해당 배치를 몇 번 불러올지 정한다. (배치를 10번 불러온다)
주의사항: 순서에 따라 결과가 다릅니다! 예를 들어, Dataset().batch().take() 와 Dataset().take().batch() 의 결과값이 다릅니다! 자세한 사항은 상세설명 6에 있습니다.
상세 설명
1. 함수 count
함수 count는 yield 키워드를 보면 알 수 있듯이, generator 입니다.
2. Dataset.from_generator(count)
count 제너레이터를 활용하여, tf.data.Dataset인 ds_counter를 만들었습니다.
3. repeat()
데이터 셋을 반복합니다. 뒤에서 설명하겠지만, batch 8 과 take 10 으로 인해 데이터를 총 80번 불러오게 됩니다. 위에서 만든 데이터 셋 ds_counter의 제너레이터는 0부터 24까지 숫자를 만듭니다. 하지만 이후, 종료되지 않고 다시 제너레이터가 반복됩니다.
출력문을 보면 0~24를 계속 반복하면서 만들어 내는 것을 볼 수 있습니다.
def count(stop):
i = 0
while i<stop:
yield i
i += 1
ds_counter = tf.data.Dataset.from_generator(count, args=[25], output_types=tf.int32, output_shapes = (), )
for count_batch in ds_counter.batch(8).take(10): # repeat()이 없다면
print(count_batch.numpy())
# [0 1 2 3 4 5 6 7]
# [ 8 9 10 11 12 13 14 15]
# [16 17 18 19 20 21 22 23]
# [24]
4. batch(n)
batch 사이즈를 n으로 설정합니다. 여기서는 batch(8)이므로, 숫자가 8개씩 들어갑니다. 이미지라면 1배치당 8개의 이미지가 들어갑니다.
만약 batch가 없다면 리스트에 담기지도 않고 데이터 한개씩 불러오게 됩니다.
def count(stop):
i = 0
while i<stop:
yield i
i += 1
ds_counter = tf.data.Dataset.from_generator(count, args=[25], output_types=tf.int32, output_shapes = (), )
for count_batch in ds_counter.take(50): # batch()가 없다면
print(count_batch.numpy())
# 0
# 1
# 2
# ...
# 24
# 0
# 1
# 2
# ...
# 24.
for count_batch in ds_counter.take(5): # batch()가 없다면
print(count_batch.numpy())
# 0
# 1
# 2
# 3
# 4
# 끝
5. take(m)
해당 배치를 m번 반복합니다. 여기서는 batch(10)이므로, 배치가 10번 반복됩니다. 위의 batch 8 * take 10 해서 총 80 개의 데이터 또는 이미지를 불러오게 됩니다.
만약 repeat()은 있고, take()가 없다면 특별한 조건이 없다면 계속 데이터셋을 불러오게 됩니다. 계속..... KeyboardInterrupt로 종료시키면 됩니다....
def count(stop):
i = 0
while i<stop:
yield i
i += 1
ds_counter = tf.data.Dataset.from_generator(count, args=[25], output_types=tf.int32, output_shapes = (), )
for count_batch in ds_counter.repeat().batch(8): # take()가 없다면
print(count_batch.numpy())
# [0 1 2 3 4 5 6 7]
# [ 8 9 10 11 12 13 14 15]
# [16 17 18 19 20 21 22 23]
# [24 0 1 2 3 4 5 6]
# [ 7 8 9 10 11 12 13 14]
# [15 16 17 18 19 20 21 22]
# [23 24 0 1 2 3 4 5]
# [ 6 7 8 9 10 11 12 13]
# [14 15 16 17 18 19 20 21]
# [22 23 24 0 1 2 3 4]
# .... 계속 반복됩니다!!
6. 순서
순서에 따라 결과가 다릅니다. 예시와 출력문을 참고하세요!
def count(stop):
i = 0
while i<stop:
yield i
i += 1
ds_counter = tf.data.Dataset.from_generator(count, args=[25], output_types=tf.int32, output_shapes = (), )
for count_batch in ds_counter.batch(5).take(1): # 순서 유의!
print(count_batch.numpy())
# [0 1 2 3 4]
for count_batch in ds_counter.batch(1).take(5): # 순서 유의!
print(count_batch.numpy())
# [0]
# [1]
# [2]
# [3]
# [4]
참고하면 좋은 글
'Tensorflow' 카테고리의 다른 글
케라스 Conv-LSTM을 활용한 영상 예측 예제 (0) | 2020.11.03 |
---|---|
케라스 예제 번역: Simple MNIST convnet (0) | 2020.11.03 |
Keras 이미지 Extract features 및 Feature map 그리기 (0) | 2020.10.29 |
Tensorflow 케라스 EfficientNet Finetuning 예제 - 1 (0) | 2020.10.28 |
Tensorflow: input_tensor와 input_shape의 차이 (0) | 2020.10.28 |