В официальном руководстве по Tensorflow мы нашли два кратких руководства по

И после того, как мы рассмотрели оба руководства, мы обнаружили, что основное различие между ними заключается в том, что новичок использовал model.fit(), а другой эксперт использовал пользовательский цикл обучения train_step() с использованием tf.GradientTape().

В этой статье мы расскажем о двух простых примерах обоих решений и укажем на несколько сценариев, в которых разработчик ДОЛЖЕН использовать собственный цикл обучения.

API высокого уровня: model.fit()

Вот пример того, как вы можете использовать model.fit для обучения модели:

# Define the model
model = MyModel()

# Compile the model with a loss function and an optimizer
model.compile(loss=tf.losses.mean_squared_error, optimizer=tf.optimizers.SGD(learning_rate=0.001))

# Generate some synthetic data
x = np.random.rand(64, 10)
y = np.random.rand(64, 1)

# Use the model.fit method to train the model
model.fit(x, y, epochs=10, batch_size=32)

Пользовательский цикл обучения: tf.GradientTape()

# Define the model
model = MyModel()

# Compile the model with a loss function and an optimizer
model.compile(loss=tf.losses.mean_squared_error, optimizer=tf.optimizers.SGD(learning_rate=0.001))

# Generate some synthetic data
x = np.random.rand(64, 10)
y = np.random.rand(64, 1)

# Use the model to make predictions
with tf.GradientTape() as tape:
  logits = model(x)
  loss_value = tf.losses.mean_squared_error(y, logits)

# Compute gradients
grads = tape.gradient(loss_value, model.trainable_variables)

# Use the optimizer to update the model's weights
optimizer.apply_gradients(zip(grads, model.trainable_variables))

Заключение и совет

Нет необходимости писать собственный обучающий цикл и использовать tf.GradientTape() самостоятельно в TensorFlow. Метод model.fit() предоставляет удобный и высокоуровневый интерфейс для обучающих моделей, и он будет обрабатывать низкоуровневые детали процесса обучения за вас, включая вычисление градиентов и обновление весов модели.

Однако иногда может потребоваться больший контроль над процессом обучения или выполнение дополнительных операций во время обучения. В этих случаях вы можете написать собственный цикл обучения, используя tf.GradientTape().

Вот несколько примеров того, что вы можете настроить в пользовательском цикле обучения:

  • Пакетная обработка. При использовании пользовательского цикла обучения вы можете выбрать способ разделения данных на пакеты и порядок повторения этих пакетов. Напротив, при использовании model.fit() вы можете указать размер пакета, и метод model.fit() выполнит пакетирование за вас.
  • Перемешивание. При использовании пользовательского цикла обучения вы можете выбрать, следует ли перемешивать данные и как их перемешивать. Напротив, при использовании model.fit() вы можете указать, следует ли перетасовывать данные, и метод model.fit() сделает это за вас.
  • Проверка. При использовании пользовательского цикла обучения вы можете выбрать способ разделения данных на наборы для обучения и проверки и способ оценки модели в наборе для проверки. Напротив, при использовании model.fit,вы можете указать разделение проверки, и метод model.fit выполнит проверку за вас.
  • Ведение журнала. При использовании пользовательского цикла обучения вы можете выбрать, какую информацию регистрировать и как ее регистрировать. Напротив, при использовании model.fit вы можете указать tf.keras.callbacks.EarlyStoppingдля обработки ведения журнала.
  • Ранняя остановка. При использовании пользовательского цикла обучения вы можете реализовать раннюю остановку, проверив производительность на проверочном наборе и остановив обучение, если производительность не улучшается. Напротив, при использовании model.fit вы можете использовать обратный вызов tf.keras.callbacks.EarlyStopping для обработки ранней остановки.
  • Сохранение и восстановление. При использовании пользовательского цикла обучения вы можете выбрать способ сохранения и восстановления модели во время и после обучения. Напротив, при использовании model.fit вы можете использовать tf.keras.callbacks.ModelCheckpointcallback для обработки сохранения и восстановления модели за вас.

Используя пользовательский цикл обучения, вы полностью контролируете процесс обучения и можете реализовать любую пользовательскую логику, которая вам нужна. Однако написание пользовательского обучающего цикла требует больше кода и более подвержено ошибкам, чем использование метода model.fit, который обрабатывает многие из этих деталей за вас.