from neon.util.persist import load_obj |
# Deserialize the saved pre-trained model and pull out the list of its
# per-layer entries stored under model -> config -> layers.
pre_trained_model = load_obj(filepath)
pre_trained_layers = pre_trained_model['model']['config']['layers']

# Take a plain-list copy of the new model's layers to iterate over.
new_layers = list(new_model.layers.layers)

for i, layer in enumerate(new_layers):
    # load_pre_trained_weight(i, layer) decides, layer by layer, whether
    # this layer should be seeded from the pre-trained model.
    if load_pre_trained_weight(i, layer):
        # Hand the layer its OWN saved entry (position i), not the whole
        # list of entries.
        # NOTE(review): assumes pre_trained_layers[i] lines up with
        # new_layers[i] for every index the predicate accepts -- confirm
        # against the pre-trained model's layer ordering.
        layer.load_weights(pre_trained_layers[i], load_states=True)
def main():
    """Run the example end to end: build the model, train it, then test it.

    Steps, in order: gather arguments and hyper-parameters, generate the
    compute backend, build the training sampler, create the model and its
    optimizer, optionally seed the model with ImageNet pre-trained weights,
    train, build the test sampler, and evaluate.
    """
    # User-supplied arguments plus the hyper-parameter settings.
    args, hyper_params = get_args_and_hyperparameters()

    # Bring up the CPU or GPU backend, forwarding only the arguments that
    # gen_backend accepts.
    backend = gen_backend(**extract_valid_args(args, gen_backend))

    # Training data. Per the original notes, the dataset is downloaded from
    # the web on first use and cached locally afterwards.
    # NOTE(review): the literal `...` appears to stand in for arguments
    # omitted from this excerpt -- fill them in before running.
    training_data = MultiscaleSampler('trainval', '2007', ...)

    # Build the model: AlexNet with its classification layer swapped out
    # for new adaptation layers.
    model, optimizer = create_model(args, hyper_params)

    # Only when no model file was supplied, and the hyper-parameters ask for
    # it, seed the AlexNet conv layers with pre-trained weights.
    no_model_file = args.model_file is None
    if no_model_file and hyper_params.use_pre_trained_weights:
        load_imagenet_weights(model, args.data_dir)

    train(args, hyper_params, model, optimizer, training_data)

    # Test data; downloaded and cached the same way as the training set.
    # NOTE(review): `...` again looks like a placeholder for omitted arguments.
    evaluation_data = MultiscaleSampler('test', '2007', ...)
    test(args, hyper_params, model, evaluation_data)
欢迎光临 电子技术论坛_中国专业的电子工程师学习交流社区-中电网技术论坛 (http://bbs.eccn.com/) | Powered by Discuz! 7.0.0 |