AI/TensorFlow & Keras

[Keras] Keras Applications 사용 방법

#자유의날개 2022. 4. 29. 10:19
반응형
  • 케라스 어플리케이션스는 ImageNet 데이터셋(1400만 개 이상의 이미지, 2만 개 이상의 클래스)으로 학습한 이미지 분류 모델들을 사용할 수 있도록 해줌
  • ImageNet 데이터셋으로 학습한 가중치를 사용하지 않고 새로운 데이터로 학습시킬 수 있음
  • 입력 및 출력 레이어를 변경할 수 있고, 다른 신경망을 이어붙일 수 있음
  • 성능이 뛰어난 모델들이지만 파라미터 수가 워낙 많기 때문에(1M~100M) 학습 시간이 매우 오래 걸리거나 오버피팅되는 모델을 만들 수 있음
  • 모델의 구조를 본따서 직접 만든 모델이 효율적일 수 있음

사용 방법

 

Keras documentation: Keras Applications

Keras Applications Keras Applications are deep learning models that are made available alongside pre-trained weights. These models can be used for prediction, feature extraction, and fine-tuning. Weights are downloaded automatically when instantiating a mo

keras.io

 

  • 모델 이름을 import 만 하면 사용할 수 있음
# ResNet50 모델 사용하기
from tensorflow.keras.applications.resnet50 import ResNet50
model = ResNet50(weights=None)

# 모델 구조 출력
model.summary()

 

  • 주요 인수
  1. include_top : CNN 신경망의 마지막에 FCN을 넣을지 여부 확인, 직접 만들어서 넣을 경우는 False, 제공해주는 FCN을 사용할 경우는 True
  2. weights : 모델 구조만 사용하여 새로운 데이터에 대해 학습을 진행할 경우 임의의 초기값 설정은 None
  3. input_shape : 인풋 이미지의 형태, include_top=Fasle 여야 함
# InceptionV3 모델 사용하기
from tensorflow.keras.applications.inception_v3 import InceptionV3
model = InceptionV3(input_tensor=input_tensor, weights=None, include_top=False)

# 모델 구조 출력
model.summary()

 

  • 모델 커스터마이징
  1. 사용하고자 하는 어플레이션의 모델을 import
  2. functional API 방식으로 모델 커스터마이징
from tensorflow.keras.layers import Dense, GlobalAveragePooling2D
from tensorflow.keras.models import Model

# VGG16 모델 사용하기
from tensorflow.keras.applications.vgg16 import VGG16

# 1. 모델 구조만 사용하고 새로운 데이터로 학습 : weights=None
# 2. FCN은 직접 만들어서 연결 : include_top=False
# 3. 인풋 이미지 형태 : input_shape=(100, 100, 3)
model = VGG16(weights=None, include_top=False, input_shape=(100, 100, 3))

# 새로운 FCN 혹은 레이어 이어 붙이기
x = model.output
x = Dense(1024, activation='relu')(x)
x = Dense(512, activation='relu')(x)

# 분류 모델을 만들기 위해 출력 레이어 설정
predictions = Dense(10, activation='softmax')(x)

# 모델 정의
model = Model(inputs=model.input, outputs=predictions, name="VGG16_ver2")
model.summary()
반응형