deepdrivemd.models.aae_stream.train

Functions

build_model(cfg)

main(cfg)

next_input(cfg, streams)

Read the next batch of contact maps from aggregated files.

train(train_loader, model, disc_optimizer, ...)

train_model(model, ae_optimizer, ...)

validate(valid_loader, model, device, cfg)

wait_for_input(cfg)

Wait for the expected number of sufficiently large agg.bp files to be produced.

deepdrivemd.models.aae_stream.train.build_model(cfg: deepdrivemd.models.aae_stream.config.Point3dAAEConfig)
deepdrivemd.models.aae_stream.train.main(cfg: deepdrivemd.models.aae_stream.config.Point3dAAEConfig)
deepdrivemd.models.aae_stream.train.next_input(cfg: deepdrivemd.models.aae_stream.config.Point3dAAEConfig, streams: deepdrivemd.data.stream.aggregator_reader.Streams) → Tuple[numpy.ndarray, numpy.ndarray]

Read the next batch of contact maps from aggregated files.

Returns

Tuple[np.ndarray, np.ndarray] – Training and validation sets.
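
The returned pair can be pictured with a minimal sketch of a train/validation split over one batch. This is not the implementation of next_input: the helper name split_train_valid, the split_pct and seed parameters, and the shuffling step are all assumptions made for illustration.

```python
from typing import Tuple

import numpy as np


def split_train_valid(
    data: np.ndarray, split_pct: float = 0.8, seed: int = 0
) -> Tuple[np.ndarray, np.ndarray]:
    """Hypothetical helper: shuffle a batch and split it along axis 0.

    split_pct is the fraction of samples placed in the training set.
    """
    if not 0.0 < split_pct < 1.0:
        raise ValueError("split_pct must be strictly between 0 and 1")
    rng = np.random.default_rng(seed)
    # Shuffle sample indices so both subsets draw from the whole batch.
    indices = rng.permutation(len(data))
    n_train = int(len(data) * split_pct)
    return data[indices[:n_train]], data[indices[n_train:]]


# Example: 10 samples with 3 features each.
batch = np.arange(30).reshape(10, 3)
train_set, valid_set = split_train_valid(batch, split_pct=0.8)
```

Fixing the seed keeps the split reproducible between runs; whether the actual function shuffles or splits in arrival order is not stated here.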

deepdrivemd.models.aae_stream.train.train(train_loader, model: mdlearn.nn.models.aae.point_3d_aae.AAE3d, disc_optimizer, ae_optimizer, device, cfg: deepdrivemd.models.aae_stream.config.Point3dAAEConfig)
deepdrivemd.models.aae_stream.train.train_model(model, ae_optimizer, disc_optimizer, train_loader, valid_loader, device, cfg: deepdrivemd.models.aae_stream.config.Point3dAAEConfig)
deepdrivemd.models.aae_stream.train.validate(valid_loader, model: mdlearn.nn.models.aae.point_3d_aae.AAE3d, device, cfg: deepdrivemd.models.aae_stream.config.Point3dAAEConfig)
deepdrivemd.models.aae_stream.train.wait_for_input(cfg: deepdrivemd.models.aae_stream.config.Point3dAAEConfig) → List[str]

Wait for the expected number of sufficiently large agg.bp files to be produced.

Returns

List[str] – List of paths to aggregated files.
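
The waiting behaviour described above can be sketched as a simple polling loop. This is an illustration under stated assumptions, not the code of wait_for_input: the names wait_for_files and path_size, the glob pattern argument, the min_size threshold, the polling interval, and the timeout are all hypothetical, and the real function takes its settings from cfg.

```python
import glob
import os
import time
from typing import List


def path_size(path: str) -> int:
    """Size in bytes of a file, or the summed size of files under a directory."""
    if os.path.isdir(path):
        total = 0
        for root, _dirs, files in os.walk(path):
            for name in files:
                total += os.path.getsize(os.path.join(root, name))
        return total
    return os.path.getsize(path)


def wait_for_files(
    pattern: str,
    num_files: int,
    min_size: int,
    poll_seconds: float = 1.0,
    timeout: float = 60.0,
) -> List[str]:
    """Hypothetical helper: poll until enough large-enough paths match pattern."""
    deadline = time.monotonic() + timeout
    while True:
        # Keep only matches that have reached the minimum size.
        ready = sorted(p for p in glob.glob(pattern) if path_size(p) >= min_size)
        if len(ready) >= num_files:
            return ready
        if time.monotonic() >= deadline:
            raise TimeoutError(
                f"only {len(ready)} of {num_files} files ready for {pattern!r}"
            )
        time.sleep(poll_seconds)
```

A call such as wait_for_files("/some/dir/*/agg.bp", num_files=2, min_size=1024) would return the sorted list of matching paths once at least two of them reach 1 KiB; the directory layout in that pattern is itself an assumption. The size check guards against reading a file that an aggregator has created but not yet filled.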