deepdrivemd.models.aae_stream.train
Functions
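| Function | Description |
|---|---|
| build_model(cfg) | |
| main(cfg) | |
| next_input(cfg, streams) | Read the next batch of contact maps from aggregated files. |
| train(train_loader, model, disc_optimizer, ae_optimizer, device, cfg) | |
| train_model(model, ae_optimizer, disc_optimizer, train_loader, valid_loader, device, cfg) | |
| validate(valid_loader, model, device, cfg) | |
| wait_for_input(cfg) | Wait for the expected number of sufficiently large agg.bp files to be produced. |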
- deepdrivemd.models.aae_stream.train.build_model(cfg: deepdrivemd.models.aae_stream.config.Point3dAAEConfig)
- deepdrivemd.models.aae_stream.train.main(cfg: deepdrivemd.models.aae_stream.config.Point3dAAEConfig)
- deepdrivemd.models.aae_stream.train.next_input(cfg: deepdrivemd.models.aae_stream.config.Point3dAAEConfig, streams: deepdrivemd.data.stream.aggregator_reader.Streams) → Tuple[numpy.ndarray, numpy.ndarray]
Read the next batch of contact maps from aggregated files.
- Returns
Tuple[np.ndarray, np.ndarray] – Training and validation sets.
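The entry above documents only the return contract of next_input. As a minimal sketch, assuming the batch read from the streams arrives as a single NumPy array and is split randomly, the final step of producing the training and validation sets might look like this. `split_train_valid`, `split_pct`, and `seed` are hypothetical names, not part of the DeepDriveMD API, and the actual reading from `Streams` is omitted.

```python
from typing import Tuple

import numpy as np


def split_train_valid(
    contact_maps: np.ndarray, split_pct: float = 0.8, seed: int = 0
) -> Tuple[np.ndarray, np.ndarray]:
    """Hypothetical helper (not the DeepDriveMD implementation): shuffle a
    batch along axis 0 and split it into training and validation sets."""
    if not 0.0 < split_pct < 1.0:
        raise ValueError("split_pct must be in (0, 1)")
    rng = np.random.default_rng(seed)
    # Permute sample indices so the split is not biased by read order.
    indices = rng.permutation(len(contact_maps))
    n_train = int(split_pct * len(contact_maps))
    train = contact_maps[indices[:n_train]]
    valid = contact_maps[indices[n_train:]]
    return train, valid


# Usage: 10 synthetic samples, each a 3 x 28 array.
batch = np.random.default_rng(1).normal(size=(10, 3, 28)).astype(np.float32)
train, valid = split_train_valid(batch, split_pct=0.8)
print(train.shape, valid.shape)  # (8, 3, 28) (2, 3, 28)
```

Shuffling before splitting keeps the validation set from consisting only of the most recently streamed samples; whether the real function does this is not stated in the docstring.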
- deepdrivemd.models.aae_stream.train.train(train_loader, model: mdlearn.nn.models.aae.point_3d_aae.AAE3d, disc_optimizer, ae_optimizer, device, cfg: deepdrivemd.models.aae_stream.config.Point3dAAEConfig)
- deepdrivemd.models.aae_stream.train.train_model(model, ae_optimizer, disc_optimizer, train_loader, valid_loader, device, cfg: deepdrivemd.models.aae_stream.config.Point3dAAEConfig)
- deepdrivemd.models.aae_stream.train.validate(valid_loader, model: mdlearn.nn.models.aae.point_3d_aae.AAE3d, device, cfg: deepdrivemd.models.aae_stream.config.Point3dAAEConfig)
- deepdrivemd.models.aae_stream.train.wait_for_input(cfg: deepdrivemd.models.aae_stream.config.Point3dAAEConfig) → List[str]
Wait for the expected number of sufficiently large agg.bp files to be produced.
- Returns
List[str] – List of paths to aggregated files.
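A minimal polling sketch of the behavior described for wait_for_input, assuming the aggregated files can be located with a glob pattern and that "sufficiently large" means a minimum size in bytes. The helper `wait_for_agg_files` and its parameters (`pattern`, `expected`, `min_bytes`, `poll_interval`, `timeout`) are hypothetical; the real function takes only `cfg`. Because a `.bp` output written by ADIOS2 is often a directory rather than a single file, the sketch sums the sizes of all files under a directory.

```python
import glob
import os
import time
from typing import List, Optional


def _path_size(path: str) -> int:
    """Total size in bytes of a file, or of all files under a directory."""
    try:
        if os.path.isdir(path):
            return sum(
                os.path.getsize(os.path.join(root, name))
                for root, _, files in os.walk(path)
                for name in files
            )
        return os.path.getsize(path)
    except OSError:
        # The path may vanish or be mid-write; treat it as not ready yet.
        return 0


def wait_for_agg_files(
    pattern: str,
    expected: int,
    min_bytes: int,
    poll_interval: float = 5.0,
    timeout: Optional[float] = None,
) -> List[str]:
    """Hypothetical helper: poll until at least `expected` paths matching
    `pattern` are each at least `min_bytes` large, then return them."""
    start = time.monotonic()
    while True:
        ready = sorted(p for p in glob.glob(pattern) if _path_size(p) >= min_bytes)
        if len(ready) >= expected:
            return ready
        if timeout is not None and time.monotonic() - start >= timeout:
            raise TimeoutError(f"only {len(ready)} of {expected} files ready")
        time.sleep(poll_interval)


# Usage: two hypothetical aggregator directories, each holding an agg.bp directory.
import tempfile

with tempfile.TemporaryDirectory() as tmp:
    for name in ("agg_0", "agg_1"):
        bp_dir = os.path.join(tmp, name, "agg.bp")
        os.makedirs(bp_dir)
        with open(os.path.join(bp_dir, "data.0"), "wb") as f:
            f.write(b"\0" * 2048)
    found = wait_for_agg_files(
        os.path.join(tmp, "*", "agg.bp"),
        expected=2, min_bytes=1024, poll_interval=0.1, timeout=5,
    )
    print(len(found))  # 2
```

Polling with a sleep keeps the sketch dependency-free. The timeout is an addition of the sketch, so that a stalled producer raises an error instead of blocking forever; the docstring above does not say whether the real function ever gives up.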