xrbm.train package

Submodules

xrbm.train.cdk module

Contrastive Divergence Gradient Approximator

class xrbm.train.cdk.CDApproximator(learning_rate, momentum=0, k=1, regularizer=None)

Bases: object

Contrastive Divergence Gradient Approximator
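For reference, the standard CD-k approximation for a binary RBM is shown below. The symbols W (weights), b (visible biases), c (hidden biases) and epsilon (the step size, presumably what learning_rate controls) are chosen here and are not names from this page. CD-k replaces the intractable model expectation in the log-likelihood gradient with statistics taken after k steps of Gibbs sampling started at the data. Here the subscript 0 denotes averages over the training data and k denotes averages after k Gibbs steps. Models that take conditioning input (see in_data on train) would add further terms that are not shown.

```latex
\Delta W_{ij} = \epsilon \left( \langle v_i h_j \rangle_{0} - \langle v_i h_j \rangle_{k} \right), \qquad
\Delta b_i = \epsilon \left( \langle v_i \rangle_{0} - \langle v_i \rangle_{k} \right), \qquad
\Delta c_j = \epsilon \left( \langle h_j \rangle_{0} - \langle h_j \rangle_{k} \right)
```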

apply_updates(model, grads)

Updates the model parameters from the given gradients, using momentum.
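This page does not say which momentum variant apply_updates uses. The sketch below assumes classic (heavy-ball) momentum applied while minimizing a cost. The helper momentum_update is hypothetical and is not part of xrbm. With momentum=0, which is the constructor's default, it reduces to plain gradient descent.

```python
def momentum_update(params, grads, velocities, learning_rate, momentum=0.0):
    """Hypothetical helper: one classic-momentum step that minimizes a cost.

    Returns (new_params, new_velocities); the inputs are not modified.
    """
    new_params, new_velocities = [], []
    for p, g, v in zip(params, grads, velocities):
        # The velocity accumulates a decaying sum of past negative gradients.
        v = momentum * v - learning_rate * g
        new_params.append(p + v)
        new_velocities.append(v)
    return new_params, new_velocities


# Usage: two steps with a constant gradient of 0.5.
params, vel = momentum_update([1.0], [0.5], [0.0], learning_rate=0.1, momentum=0.9)
params, vel = momentum_update(params, [0.5], vel, learning_rate=0.1, momentum=0.9)
```

The second step moves further than the first (0.095 versus 0.05), because the velocity carries over part of the previous update.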

compute_gradients(cost, params, var_list=None)

Computes the gradients of the given cost function with respect to the tensors in params.
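The sketch below shows concretely what "gradients of the cost with respect to params" contains for a binary RBM's weight matrix. It assumes that the cost is the difference between the mean free energy of the data and that of the chain samples, with the samples held fixed. This is a common formulation in TensorFlow/Theano RBM code, but this page does not confirm it for xrbm. xrbm presumably obtains these gradients by TensorFlow's automatic differentiation. The hand-derived pure-Python version here, with hypothetical helper names, only illustrates the quantity.

```python
import math


def sigmoid(x):
    return 1.0 / (1.0 + math.exp(-x))


def softplus(x):
    # log(1 + e^x); adequate for the small inputs used in this sketch
    return math.log1p(math.exp(x))


def hidden_input(v, W, c, j):
    # Total input to hidden unit j: c_j + sum_i v_i W_ij
    return c[j] + sum(v[i] * W[i][j] for i in range(len(v)))


def free_energy(v, W, b, c):
    # F(v) = -sum_i b_i v_i - sum_j softplus(c_j + sum_i v_i W_ij)
    vis_term = sum(bi * vi for bi, vi in zip(b, v))
    hid_term = sum(softplus(hidden_input(v, W, c, j)) for j in range(len(c)))
    return -vis_term - hid_term


def free_energy_grad_W(v, W, c):
    # dF/dW_ij = -v_i * sigmoid(c_j + sum_i v_i W_ij)
    act = [sigmoid(hidden_input(v, W, c, j)) for j in range(len(c))]
    return [[-v[i] * act[j] for j in range(len(c))] for i in range(len(v))]


def cd_cost_grad_W(data, samples, W, c):
    """Gradient of mean F(data) - mean F(samples) w.r.t. W, samples held fixed."""
    n_vis, n_hid = len(W), len(c)
    grad = [[0.0] * n_hid for _ in range(n_vis)]
    for batch, sign in ((data, 1.0), (samples, -1.0)):
        for v in batch:
            g = free_energy_grad_W(v, W, c)
            for i in range(n_vis):
                for j in range(n_hid):
                    grad[i][j] += sign * g[i][j] / len(batch)
    return grad
```

Descending this gradient lowers the free energy of the data relative to the samples. That corresponds to the ascent direction of the CD-k weight update.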

train(model, vis_data, in_data=[], global_step=None, var_list=None, name=None)

Performs one training step of the CD-k algorithm: approximates the gradients of the model parameters and updates the parameters accordingly.
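One CD-k step can be sketched in pure Python for a plain binary RBM with weights W, visible biases b and hidden biases c. The function cd_k_step and its helpers are hypothetical and are not xrbm's implementation. The sketch omits momentum, regularization and conditioning input (in_data). Following common practice, the positive statistics use hidden probabilities computed from the data. The negative statistics use a sampled visible vector after k Gibbs steps together with its hidden probabilities.

```python
import math
import random


def sigmoid(x):
    return 1.0 / (1.0 + math.exp(-x))


def sample_hidden(v, W, c, rng):
    # p(h_j = 1 | v) and a Bernoulli sample of h
    probs = [sigmoid(c[j] + sum(v[i] * W[i][j] for i in range(len(v)))) for j in range(len(c))]
    return probs, [1.0 if rng.random() < p else 0.0 for p in probs]


def sample_visible(h, W, b, rng):
    # p(v_i = 1 | h) and a Bernoulli sample of v
    probs = [sigmoid(b[i] + sum(W[i][j] * h[j] for j in range(len(h)))) for i in range(len(b))]
    return probs, [1.0 if rng.random() < p else 0.0 for p in probs]


def cd_k_step(batch, W, b, c, learning_rate, k=1, rng=None):
    """Hypothetical helper: one CD-k update for a binary RBM; returns new (W, b, c)."""
    assert k >= 1, "CD-k needs at least one Gibbs step"
    if rng is None:
        rng = random.Random(0)
    n_vis, n_hid = len(b), len(c)
    dW = [[0.0] * n_hid for _ in range(n_vis)]
    db = [0.0] * n_vis
    dc = [0.0] * n_hid
    for v0 in batch:
        ph0, h = sample_hidden(v0, W, c, rng)
        # Run k steps of block Gibbs sampling, starting from the data vector.
        for _ in range(k):
            _, vk = sample_visible(h, W, b, rng)
            phk, h = sample_hidden(vk, W, c, rng)
        # Positive phase (data) minus negative phase (after k Gibbs steps).
        for i in range(n_vis):
            for j in range(n_hid):
                dW[i][j] += v0[i] * ph0[j] - vk[i] * phk[j]
            db[i] += v0[i] - vk[i]
        for j in range(n_hid):
            dc[j] += ph0[j] - phk[j]
    n = len(batch)
    # Gradient ascent on the approximate log-likelihood, averaged over the batch.
    new_W = [[W[i][j] + learning_rate * dW[i][j] / n for j in range(n_hid)] for i in range(n_vis)]
    new_b = [b[i] + learning_rate * db[i] / n for i in range(n_vis)]
    new_c = [c[j] + learning_rate * dc[j] / n for j in range(n_hid)]
    return new_W, new_b, new_c


# Usage: one CD-2 step on a tiny 3-visible / 2-hidden model.
W = [[0.0, 0.0] for _ in range(3)]
b, c = [0.0] * 3, [0.0] * 2
batch = [[1.0, 0.0, 1.0], [0.0, 1.0, 0.0]]
W, b, c = cd_k_step(batch, W, b, c, learning_rate=0.1, k=2, rng=random.Random(42))
```

Passing an explicit random.Random makes a step reproducible, which helps when comparing different values of k.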

Module contents