IMP  2.0.1
The Integrative Modeling Platform
RMSDClustering.h
Go to the documentation of this file.
1 /**
2  * \file IMP/multifit/RMSDClustering.h
3  * \brief Cluster transformations by rmsd
4  *
5  * Copyright 2007-2013 IMP Inventors. All rights reserved.
6  *
7  */
8 
9 #ifndef IMPMULTIFIT_RMSD_CLUSTERING_H
10 #define IMPMULTIFIT_RMSD_CLUSTERING_H
11 
12 #include <IMP/multifit/multifit_config.h>
13 #include "GeometricHash.h"
14 #include <IMP/algebra/Vector3D.h>
17 #include <IMP/core/XYZ.h>
18 #include <boost/graph/adjacency_list.hpp>
19 #include <IMP/atom/distance.h>
20 IMPMULTIFIT_BEGIN_NAMESPACE
21 
22 //! RMSD clustering
23 /**
24  /note Iteratively joins pairs of close transformations. The algorithm first
25  clusters transformations for which the transformed centroids are close
26  (fall into the same bin in a hash). Then, all clusters are globally
27  reclustered.
28  /note TransT should implement the functions:
29  join_into() add a transformation to the current cluster and
30  possibly updates the representative transformation for the
31  cluster
32  get_score() that returns the score (higher score is better)
33  update_score() that updates the score of the
34  cluster according to a new member
35  get_representative_transformation() a function that returns the
36  the representative transformation for a cluster
37 
38 */
39 template <class TransT>
41 private:
42 //! Base class for transformation record
43 class TransformationRecord {
44 public:
45  ////standard constructor.
46  inline TransformationRecord(const TransT &trans):
47  valid_(true), trans_(trans) {
48  }
49  virtual ~TransformationRecord() {}
50  //! Join the transformations into this.
51  void join_into(const TransformationRecord& record) {
52  trans_.update_score(record.trans_.get_score());
53  trans_.join_into(record.trans_);
54  }
55  inline float get_score() const { return trans_.get_score();}
56  const algebra::Vector3D get_centroid() const { return centroid_; }
57  void set_centroid(algebra::Vector3D& centroid) {
58  centroid_ = trans_.get_representative_transformation().get_transformed(
59  centroid); }
60  TransT get_record() const {return trans_;}
61  bool get_valid() const {return valid_;}
62  void set_valid(bool v) {valid_=v;}
63 protected:
64  bool valid_;
65  TransT trans_;
66  algebra::Vector3D centroid_;
67 };
68 typedef std::vector<TransformationRecord> TransformationRecords;
69 public:
70  typedef GeometricHash<int, 3> Hash3;
71  typedef boost::property<boost::edge_weight_t, short> ClusEdgeWeightProperty;
72  typedef boost::property<boost::vertex_index_t, int> ClusVertexIndexProperty;
73  // Graph type
74  typedef boost::adjacency_list<boost::vecS, boost::vecS, boost::undirectedS,
75  ClusVertexIndexProperty, ClusEdgeWeightProperty> Graph;
76  typedef boost::graph_traits<Graph> RCGTraits;
77  typedef RCGTraits::vertex_descriptor RCVertex;
78  typedef RCGTraits::edge_descriptor RCEdge;
79  typedef RCGTraits::vertex_iterator RCVertexIt;
80  typedef RCGTraits::edge_iterator RCEdgeIt;
81 
82  struct sort_by_weight {
83  bool operator()(const std::pair<RCEdge,float> &s1,
84  const std::pair<RCEdge,float> &s2) const {
85  return s1.second < s2.second;
86  }
87  };
88  /**
89  \param[in] bin_size the radius of the bins of the hash
90  differ with at most this value
91  */
92  RMSDClustering(float bin_size=3.){is_ready_=false;bin_size_=bin_size;}
93  virtual ~RMSDClustering() {}
94  //! cluster transformations
95  void cluster(float max_dist, const std::vector<TransT>& input_trans,
96  std::vector<TransT>& output_trans);
97  //! prepare for clustering
98  void prepare(const ParticlesTemp &ps);
99  void set_bin_size(float bin_size) {bin_size_=bin_size;}
100 protected:
101  //! returns the RMSD between two transformations with respect to
102  //! the stored points
103  virtual float get_squared_distance(const TransT& trans1,
104  const TransT& trans2);
105  //clustering function
106 void build_graph(const Hash3::PointList &inds,
107  const std::vector<TransformationRecord*> &recs,
108  float max_dist,
109  Graph &g);
110 
111  void build_full_graph(const Hash3 &h,
112  const std::vector<TransformationRecord*> &recs,
113  float max_dist, Graph &g);
114 
115  int cluster_graph(Graph &g,
116  const std::vector<TransformationRecord*> &recs,
117  float max_dist);
118 
119  int fast_clustering(float max_dist,
120  std::vector<TransformationRecord *>& recs);
121 
122  virtual int exhaustive_clustering(float max_dist,
123  std::vector<TransformationRecord *>& recs);
124  //! Remove transformations which are not valid.
125  // should be used after each invocation of work.
126  virtual void clean(std::vector<TransformationRecord*>*& records);
127  bool is_fast_;
128  float bin_size_; //hash bin size
129  // The centroid of the molecule
130  algebra::Vector3D centroid_;
131  Particles ps_;
132  core::XYZs xyzs_;
133  //fast RMSD computation
134  atom::RMSDCalculator rmsd_calc_;
135  bool is_ready_;
136 };
137 
138 template<class TransT> float
140  const TransT& trans2) {
141  return rmsd_calc_.get_squared_rmsd(trans1.get_representative_transformation(),
142  trans2.get_representative_transformation());
143 }
144 
145 template<class TransT>
146 void RMSDClustering<TransT>::build_graph(const Hash3::PointList &inds,
147  const std::vector<TransformationRecord*> &recs,
148  float max_dist, Graph &g){
149  //hash all the records
150  float max_dist2=max_dist*max_dist;
151  //add nodes
152 IMP_LOG_VERBOSE("build_graph:adding nodes"<<std::endl);
153  std::vector<RCVertex> nodes(inds.size());
154  for (unsigned int i=0; i<inds.size(); ++i) {
155  nodes[i]=boost::add_vertex(i,g);
156  }
157  //add edges
158  IMP_LOG_VERBOSE("build_graph:adding edges"<<std::endl);
159  for (unsigned int i=0; i<inds.size(); ++i) {
160  for (unsigned int j=i+1; j<inds.size(); ++j) {
161  float d2 = get_squared_distance(recs[i]->get_record(),
162  recs[j]->get_record());
163  if (d2 < max_dist2) {
164  boost::add_edge(nodes[i],nodes[j],d2,g);
165  //edge_weight.push_back(std::pair<RCEdge,float>(e,d2));
166  }}}
167  IMP_LOG_VERBOSE("build_graph: done building"<<std::endl);
168 }
169 template<class TransT>
170 void RMSDClustering<TransT>::build_full_graph(const Hash3 &h,
171  const std::vector<TransformationRecord*> &recs,
172  float max_dist, Graph &g){
173  float max_dist2=max_dist*max_dist;
174  //add nodes
175  std::vector<RCVertex> nodes(recs.size());
176  for (unsigned int i=0; i<recs.size(); ++i) {
177  nodes[i]=boost::add_vertex(i,g);
178  }
179  //add edges
180  for (int i = 0 ; i < (int)recs.size() ; ++i) {
181  TransT tr=recs[i]->get_record();
182  algebra::Transformation3D t = tr.get_representative_transformation();
183  Hash3::HashResult result =
184  h.neighbors(Hash3::INF, t.get_transformed(centroid_), max_dist);
185  for ( size_t k=0; k<result.size(); ++k ) {
186  int j = result[k]->second;
187  if (i >= j) continue; //insert edge only once
188  float centroids_dist2 = algebra::get_squared_distance(
189  recs[i]->get_centroid(),
190  recs[j]->get_centroid());
191  if (centroids_dist2 < max_dist2) {
192  float d2 = get_squared_distance(recs[i]->get_record(),
193  recs[j]->get_record());
194  if (d2 < max_dist2) {
195  boost::add_edge(nodes[i],nodes[j],d2,g);
196  }
197  }}}}
198 
199 template<class TransT>
200 int RMSDClustering<TransT>::cluster_graph(Graph &g,
201  const std::vector<TransformationRecord*> &recs,
202  float max_dist) {
203  if (boost::num_edges(g)==0) return 0;
204  IMP_LOG_VERBOSE("Going to cluster a graph of:"
205  <<boost::num_vertices(g)<<std::endl);
206  float max_dist2=max_dist*max_dist;
207  //get all of the edge weights
208  boost::property_map<Graph, boost::edge_weight_t>::type
209  weight = get(boost::edge_weight, g);
210  std::vector<std::pair<RCEdge,float> > edge_weight;
211  RCEdgeIt ei, ei_end;
212  for(boost::tie(ei,ei_end) = boost::edges(g); ei != ei_end; ++ei){
213  edge_weight.push_back(std::pair<RCEdge,float>(*ei,
214  boost::get(weight,*ei)));
215  }
216  int num_joins=0;
217  //sort the edges by weight
218  std::sort(edge_weight.begin(),edge_weight.end(),sort_by_weight());
219  //sort the edges
220  std::vector<bool> used;
221  used.insert(used.end(),boost::num_vertices(g),false);
222  for(unsigned int i=0;i<edge_weight.size();i++) {
223  RCEdge e = edge_weight[i].first;
224  int v1_ind=boost::source(e,g);
225  int v2_ind=boost::target(e,g);
226  IMP_LOG_VERBOSE("Working on edge "<<i<<"bewteen nodes"<<v1_ind<<
227  " and "<<v2_ind<<std::endl);
228  //check if any end of the edge is deleted
229  if (!used[v1_ind] && !used[v2_ind] &&
230  (edge_weight[i].second < max_dist2)){
231  ++num_joins;
232  used[v1_ind] = true;
233  used[v2_ind] = true;
234 
235  TransformationRecord* rec1 = recs[v1_ind];
236  TransformationRecord* rec2 = recs[v2_ind];
237  if (!(rec1->get_valid() &&rec2->get_valid())) continue;
238  if (rec1->get_score() > rec2->get_score()) {
239  rec1->join_into(*rec2);
240  rec2->set_valid(false);
241  } else {
242  rec2->join_into(*rec1);
243  rec1->set_valid(false);
244  }
245  }
246  } // edges
247  return num_joins;
248 }
249 template<class TransT>
250 void RMSDClustering<TransT>::prepare(const ParticlesTemp& ps) {
251  rmsd_calc_=atom::RMSDCalculator(ps);
252  // save centroid
253  centroid_ = algebra::Vector3D(0,0,0);
254  core::XYZs xyzs(ps);
255  for (core::XYZs::iterator it = xyzs.begin(); it != xyzs.end(); it++) {
256  centroid_ += it->get_coordinates();
257  }
258  centroid_ /= ps.size();
259  is_ready_=true;
260 }
261 
262 template<class TransT>
263 int RMSDClustering<TransT>::fast_clustering(float max_dist,
264  std::vector<TransformationRecord*>& recs) {
265  IMP_LOG_VERBOSE("start fast clustering with "<<recs.size()<<" records\n");
266  int num_joins = 0;
267  boost::scoped_array<bool> used(new bool[recs.size()]);
268  Hash3 g_hash((double)(bin_size_));
269 
270  //load the hash
271  for (int i = 0 ; i < (int)recs.size() ; ++i){
272  used[i] = false;
273  TransT tr=recs[i]->get_record();
275  tr.get_representative_transformation();
276  algebra::Vector3D trans_cen = t.get_transformed(centroid_);
277  g_hash.add(trans_cen, i);
278  IMP_LOG_VERBOSE("add to hash vertex number:"<<i
279  <<" with center:"<<trans_cen<<std::endl);
280  }
281  //work on each bucket
282  const Hash3::GeomMap &M = g_hash.Map();
283  for (Hash3::GeomMap::const_iterator bucket = M.begin();
284  bucket != M.end() ; ++bucket){
285  const Hash3::PointList &pb = bucket->second;
286  IMP_LOG_VERBOSE("Bucket size:"<<pb.size()<<"\n");
287  // if (pb.size()<2) continue;
288  Graph g;
289  std::vector<std::pair<RCEdge,float> > edge_weight;
290  build_graph(pb,recs,max_dist,g);
291  IMP_LOG_VERBOSE("create graph with:"<<boost::num_vertices(g)<<" nodes and"<<
292  boost::num_edges(g)<<" edges out of "<<pb.size()<<" points\n");
293  //cluster all transformations in the bin
294  num_joins +=cluster_graph(g,recs,max_dist);
295  IMP_LOG_VERBOSE("after clustering number of joins::"<<num_joins<<std::endl);
296  }
297  return num_joins;
298 }
299 
300 
301 template<class TransT>
302 int RMSDClustering<TransT>::exhaustive_clustering(float max_dist,
303  std::vector<TransformationRecord *>& recs) {
304  IMP_LOG_VERBOSE("start full clustering with "<< recs.size()<<" records \n");
305  if (recs.size()<2) return 0;
306  boost::scoped_array<bool> used(new bool[recs.size()]);
307  Hash3 ghash((double)(max_dist));
308 
309  //load the hash
310  for (int i = 0 ; i < (int)recs.size() ; ++i) {
311  used[i] = false;
312  algebra::Transformation3D t =
313  recs[i]->get_record().get_representative_transformation();
314  ghash.add(t.get_transformed(centroid_), i);
315  }
316  //build the graph
317  Graph g;
318  build_full_graph(ghash,recs,max_dist,g);
319  int num_joins = cluster_graph(g,recs,max_dist);
320  return num_joins;
321 }
322 template<class TransT>
324  std::vector<TransformationRecord*>*& records) {
325  std::vector<TransformationRecord*> *results =
326  new std::vector<TransformationRecord*>();
327  for (int i = 0 ; i < (int)records->size() ; i++){
328  if ((*records)[i]->get_valid()) {
329  results->push_back((*records)[i]);
330  } else {
331  delete((*records)[i]);
332  }
333  }
334  records->clear();
335  delete records;
336  records = results;
337 }
338 template<class TransT>
340  const std::vector<TransT> &input_trans,
341  std::vector<TransT> & output) {
342  //create initial vectors of transformation records and bit vector to
343  //indicate what is deleted
344  std::vector<TransformationRecord*>* records =
345  new std::vector<TransformationRecord*>();
346  for (typename std::vector<TransT>::const_iterator
347  it = input_trans.begin();it != input_trans.end() ; ++it){
348  TransformationRecord* record = new TransformationRecord(*it);
349  record->set_centroid(centroid_);
350  records->push_back(record);
351  }
352  //fast clustering using geometric hashing
353  while (fast_clustering(max_dist, *records)){
354  clean(records);
355  }
356  clean(records);
357  //complete full clustering
358  while (exhaustive_clustering(max_dist, *records)){
359  clean(records);
360  }
361  // clean(records);
362  //build the vector for output
363  IMP_LOG_VERBOSE("build output of "<<records->size()<<" records \n");
364  for (int i = 0 ; i < (int)records->size() ; ++i){
365  output.push_back((*records)[i]->get_record());
366  delete((*records)[i]);
367  }
368  delete(records);
369  IMP_LOG_VERBOSE("returning "<< output.size()<<" records \n");
370 }
371 
372 IMPMULTIFIT_END_NAMESPACE
373 #endif /* IMPMULTIFIT_RMSD_CLUSTERING_H */