В официальном руководстве по Tensorflow мы нашли два кратких руководства по
И после того, как мы рассмотрели оба руководства, мы обнаружили, что основное различие между ними заключается в том, что новичок использовал model.fit()
, а другой эксперт использовал пользовательский цикл обучения train_step()
с использованием tf.GradientTape()
.
В этой статье мы расскажем о двух простых примерах обоих решений и укажем на несколько сценариев, в которых разработчик ДОЛЖЕН использовать собственный цикл обучения.
API высокого уровня: model.fit()
Вот пример того, как вы можете использовать model.fit для обучения модели:
# Define the model model = MyModel() # Compile the model with a loss function and an optimizer model.compile(loss=tf.losses.mean_squared_error, optimizer=tf.optimizers.SGD(learning_rate=0.001)) # Generate some synthetic data x = np.random.rand(64, 10) y = np.random.rand(64, 1) # Use the model.fit method to train the model model.fit(x, y, epochs=10, batch_size=32)
Пользовательский цикл обучения: tf.GradientTape()
# Define the model model = MyModel() # Compile the model with a loss function and an optimizer model.compile(loss=tf.losses.mean_squared_error, optimizer=tf.optimizers.SGD(learning_rate=0.001)) # Generate some synthetic data x = np.random.rand(64, 10) y = np.random.rand(64, 1) # Use the model to make predictions with tf.GradientTape() as tape: logits = model(x) loss_value = tf.losses.mean_squared_error(y, logits) # Compute gradients grads = tape.gradient(loss_value, model.trainable_variables) # Use the optimizer to update the model's weights optimizer.apply_gradients(zip(grads, model.trainable_variables))
Заключение и совет
Нет необходимости писать собственный обучающий цикл и использовать tf.GradientTape()
самостоятельно в TensorFlow. Метод model.fit()
предоставляет удобный и высокоуровневый интерфейс для обучающих моделей, и он будет обрабатывать низкоуровневые детали процесса обучения за вас, включая вычисление градиентов и обновление весов модели.
Однако иногда может потребоваться больший контроль над процессом обучения или выполнение дополнительных операций во время обучения. В этих случаях вы можете написать собственный цикл обучения, используя tf.GradientTape()
.
Вот несколько примеров того, что вы можете настроить в пользовательском цикле обучения:
- Пакетная обработка. При использовании пользовательского цикла обучения вы можете выбрать способ разделения данных на пакеты и порядок повторения этих пакетов. Напротив, при использовании
model.fit()
вы можете указать размер пакета, и методmodel.fit()
выполнит пакетирование за вас. - Перемешивание. При использовании пользовательского цикла обучения вы можете выбрать, следует ли перемешивать данные и как их перемешивать. Напротив, при использовании
model.fit()
вы можете указать, следует ли перетасовывать данные, и методmodel.fit()
сделает это за вас. - Проверка. При использовании пользовательского цикла обучения вы можете выбрать способ разделения данных на наборы для обучения и проверки и способ оценки модели в наборе для проверки. Напротив, при использовании
model.fit,
вы можете указать разделение проверки, и метод model.fit выполнит проверку за вас. - Ведение журнала. При использовании пользовательского цикла обучения вы можете выбрать, какую информацию регистрировать и как ее регистрировать. Напротив, при использовании model.fit вы можете указать
tf.keras.callbacks.EarlyStopping
для обработки ведения журнала. - Ранняя остановка. При использовании пользовательского цикла обучения вы можете реализовать раннюю остановку, проверив производительность на проверочном наборе и остановив обучение, если производительность не улучшается. Напротив, при использовании model.fit вы можете использовать обратный вызов
tf.keras.callbacks.EarlyStopping
для обработки ранней остановки. - Сохранение и восстановление. При использовании пользовательского цикла обучения вы можете выбрать способ сохранения и восстановления модели во время и после обучения. Напротив, при использовании model.fit вы можете использовать
tf.keras.callbacks.ModelCheckpoint
callback для обработки сохранения и восстановления модели за вас.
Используя пользовательский цикл обучения, вы полностью контролируете процесс обучения и можете реализовать любую пользовательскую логику, которая вам нужна. Однако написание пользовательского обучающего цикла требует больше кода и более подвержено ошибкам, чем использование метода model.fit, который обрабатывает многие из этих деталей за вас.