Tensorflow

텐서플로우 Dataset: repeat(), batch(), take()

카카오그래놀라 2020. 10. 30. 01:06

tf.data.Dataset 기초

repeat(), batch(), take()

코드와 출력문을 보면 이해가 더 빠릅니다.
def count(stop):
    i = 0
    while i<stop:
        yield i
        i += 1
    
ds_counter = tf.data.Dataset.from_generator(count, args=[25], output_types=tf.int32, output_shapes = (), )

for count_batch in ds_counter.repeat().batch(8).take(10):
    print(count_batch.numpy())
 
# [0 1 2 3 4 5 6 7]
# [ 8  9 10 11 12 13 14 15]
# [16 17 18 19 20 21 22 23]
# [24  0  1  2  3  4  5  6]
# [ 7  8  9 10 11 12 13 14]
# [15 16 17 18 19 20 21 22]
# [23 24  0  1  2  3  4  5]
# [ 6  7  8  9 10 11 12 13]
# [14 15 16 17 18 19 20 21]
# [22 23 24  0  1  2  3  4]

 

간단 설명

repeat(): 데이터 셋을 반복한다.

batch(): 데이터 배치의 크기를 정한다. (위의 코드에서는 1배치 당 8개의 데이터)

take(): 해당 배치를 몇 번 불러올지 정한다. (배치를 10번 불러온다)

 

주의사항: 순서에 따라 결과가 다릅니다! 예를 들어, Dataset().batch().take() 와 Dataset().take().batch() 의 결과값이 다릅니다! 자세한 사항은 상세설명 6에 있습니다.

 

 


상세 설명

1. 함수 count

함수 count는 yield 키워드를 보면 알 수 있듯이, generator 입니다.

 

2. Dataset.from_generator(count)

count 제너레이터를 활용하여, tf.data.Dataset인 ds_counter를 만들었습니다.

 

3. repeat()

데이터 셋을 반복합니다. 뒤에서 설명하겠지만, batch 8 과 take 10 으로 인해 데이터를 총 80번 불러오게 됩니다. 위에서 만든 데이터 셋 ds_counter의 제너레이터는 0부터 24까지 숫자를 만듭니다. 하지만 이후, 종료되지 않고 다시 제너레이터가 반복됩니다.

출력문을 보면 0~24를 계속 반복하면서 만들어 내는 것을 볼 수 있습니다.

def count(stop):
    i = 0
    while i<stop:
        yield i
        i += 1
    
ds_counter = tf.data.Dataset.from_generator(count, args=[25], output_types=tf.int32, output_shapes = (), )

for count_batch in ds_counter.batch(8).take(10): # repeat()이 없다면
    print(count_batch.numpy())
 
# [0 1 2 3 4 5 6 7]
# [ 8  9 10 11 12 13 14 15]
# [16 17 18 19 20 21 22 23]
# [24]

 

4. batch(n)

batch 사이즈를 n으로 설정합니다. 여기서는 batch(8)이므로, 숫자가 8개씩 들어갑니다. 이미지라면 1배치당 8개의 이미지가 들어갑니다.

만약 batch가 없다면 리스트에 담기지도 않고 데이터 한개씩 불러오게 됩니다. 

def count(stop):
    i = 0
    while i<stop:
        yield i
        i += 1
    
ds_counter = tf.data.Dataset.from_generator(count, args=[25], output_types=tf.int32, output_shapes = (), )

for count_batch in ds_counter.take(50): # batch()가 없다면
    print(count_batch.numpy())
 
# 0
# 1
# 2
# ...
# 24
# 0
# 1
# 2
# ...
# 24.

for count_batch in ds_counter.take(5): # batch()가 없다면
    print(count_batch.numpy())
 
# 0
# 1
# 2
# 3
# 4 
# 끝

 

5. take(m)

해당 배치를 m번 반복합니다. 여기서는 batch(10)이므로, 배치가 10번 반복됩니다. 위의 batch 8 * take 10 해서 총 80 개의 데이터 또는 이미지를 불러오게 됩니다.

 만약 repeat()은 있고, take()가 없다면 특별한 조건이 없다면 계속 데이터셋을 불러오게 됩니다. 계속..... KeyboardInterrupt로 종료시키면 됩니다....

def count(stop):
    i = 0
    while i<stop:
        yield i
        i += 1
    
ds_counter = tf.data.Dataset.from_generator(count, args=[25], output_types=tf.int32, output_shapes = (), )

for count_batch in ds_counter.repeat().batch(8): # take()가 없다면
    print(count_batch.numpy())
 
# [0 1 2 3 4 5 6 7]
# [ 8  9 10 11 12 13 14 15]
# [16 17 18 19 20 21 22 23]
# [24  0  1  2  3  4  5  6]
# [ 7  8  9 10 11 12 13 14]
# [15 16 17 18 19 20 21 22]
# [23 24  0  1  2  3  4  5]
# [ 6  7  8  9 10 11 12 13]
# [14 15 16 17 18 19 20 21]
# [22 23 24  0  1  2  3  4]
# .... 계속 반복됩니다!!

 

6. 순서

순서에 따라 결과가 다릅니다. 예시와 출력문을 참고하세요!

def count(stop):
    i = 0
    while i<stop:
        yield i
        i += 1
    
ds_counter = tf.data.Dataset.from_generator(count, args=[25], output_types=tf.int32, output_shapes = (), )

for count_batch in ds_counter.batch(5).take(1): # 순서 유의!
    print(count_batch.numpy())
 
# [0 1 2 3 4]

for count_batch in ds_counter.batch(1).take(5): # 순서 유의!
    print(count_batch.numpy())

# [0]
# [1]
# [2]
# [3]
# [4]

 


참고하면 좋은 글

 

텐서플로우 Dataset: from_generator 설명

사용 이유 1. 메모리 용량의 한계 때문에 - 이미지, 텍스트 데이터 등을 메모리에 올려 놓은 후 작업을 진행할 수 있지만, 수많은 파일이 있을 경우 이를 모두 메모리에 올려놓지 못하기에, generato

deep-deep-deep.tistory.com